pythonpython-3.xpython-asyncioteardownpytest-asyncio

Pytest asyncio, howto await in setup and teardown?


I'm using pytest-asyncio to test my asyncio based library.

I'm using the class-level approach, in which a TestClass has several TestMethods which are invoked by the test framework one after each other.

The setup method initializes the ClassUnderTest. The teardown method currently does nothing. However, I commented the intended functionality in the teardown.

What I would like to do, is to implement an async teardown and/or setup, so I can await some async clean-up code. Is this possible?

I didn't find something about this in the pytest-asyncio documentation, which is very brief. Therefore, I'm asking this question. Maybe someone has stumbled over a similar problem and found a way to do it anyway.

import asyncio
import random

import pytest


class ClassUnderTest:
    def __init__(self):
        self._queue = asyncio.Queue()
        self._task1 = None
        self._task2 = None

    async def start(self):
        self._task1 = asyncio.create_task(self.producer())
        self._task2 = asyncio.create_task(self.consumer())

    async def stop(self):
        self._task1.cancel()
        self._task2.cancel()
        return await asyncio.gather(self._task1, self._task2, return_exceptions = True)

    @property
    def tasks(self):
        return self._task1, self._task2

    async def producer(self):
        try:
            while True:
                if self._queue.qsize() < 10:
                    self._queue.put_nowait(random.randint(0, 10))

                await asyncio.sleep(50)

        except asyncio.CancelledError:
            print("Finito!")
            raise

    async def consumer(self):
        try:
            while True:
                if self._queue.qsize() > 0:
                    elem = self._queue.get_nowait()
                    print(elem)

                await asyncio.sleep(100)

        except asyncio.CancelledError:
            print("Finito!")
            raise

@pytest.mark.asyncio
class TestClass:
    """ Tests my asynio code """

    def setup_method(self):
        self._my_class_under_test = ClassUnderTest()

    def teardown_method(self):
        """
        if not tasks[0].cancelled() or not tasks[1].cancelled():
            await self._my_class_under_test.stop()
        """

    async def test_start(self):
        await self._my_class_under_test.start()
        tasks = self._my_class_under_test.tasks
        assert not tasks[0].cancelled()
        assert not tasks[1].cancelled()
        await self._my_class_under_test.stop()

    async def test_stop(self):
        await self._my_class_under_test.start()
        tasks = self._my_class_under_test.tasks
        return_values = await self._my_class_under_test.stop()
        assert tasks[0].cancelled()
        assert tasks[1].cancelled()
        assert isinstance(return_values[0], asyncio.CancelledError)
        assert isinstance(return_values[1], asyncio.CancelledError)

    async def test_producer(self):
        pass

    async def test_consumer(self):
        pass


if __name__ == "__main__":
    pytest.main([__file__])

Output:

/home/user/.config/JetBrains/PyCharm2023.2/scratches/asyncio_test_setup_teardown.py 
============================= test session starts ==============================
platform linux -- Python 3.10.13, pytest-7.4.2, pluggy-1.3.0
rootdir: /home/user/.config/JetBrains/PyCharm2023.2/scratches
plugins: timeout-2.1.0, asyncio-0.21.1
asyncio: mode=strict
collected 2 items

asyncio_test_setup_teardown.py ..                                        [100%]

============================== 2 passed in 0.01s ===============================

Process finished with exit code 0

Solution

  • Create a custom pytest fixture for 'ClassUnderTest'. This fixture will handle the setup and teardown of your 'ClassUnderTest' instance:

    import pytest
    import asyncio
    import random
    
    class ClassUnderTest:
        def __init__(self):
            self._queue = asyncio.Queue()
            self._task1 = None
            self._task2 = None
    
        async def start(self):
            self._task1 = asyncio.create_task(self.producer())
            self._task2 = asyncio.create_task(self.consumer())
    
        async def stop(self):
            self._task1.cancel()
            self._task2.cancel()
            return await asyncio.gather(self._task1, self._task2, return_exceptions=True)
    
        @property
        def tasks(self):
            return self._task1, self._task2
    
        async def producer(self):
            try:
                while True:
                    if self._queue.qsize() < 10:
                        self._queue.put_nowait(random.randint(0, 10))
    
                    await asyncio.sleep(50)
    
            except asyncio.CancelledError:
                print("Finito!")
                raise
    
        async def consumer(self):
            try:
                while True:
                    if self._queue.qsize() > 0:
                        elem = self._queue.get_nowait()
                        print(elem)
    
                    await asyncio.sleep(100)
    
            except asyncio.CancelledError:
                print("Finito!")
                raise
    
    # Define a custom pytest fixture for ClassUnderTest
    @pytest.fixture
    async def class_under_test():
        my_class_under_test = ClassUnderTest()
        yield my_class_under_test
        await my_class_under_test.stop()
    
    # Use the custom fixture in your test class
    @pytest.mark.asyncio
    class TestClass:
        """ Tests my asyncio code """
    
        @pytest.mark.asyncio
        async def test_start(self, class_under_test):
            await class_under_test.start()
            tasks = class_under_test.tasks
            assert not tasks[0].cancelled()
            assert not tasks[1].cancelled()
    
        @pytest.mark.asyncio
        async def test_stop(self, class_under_test):
            await class_under_test.start()
            tasks = class_under_test.tasks
            return_values = await class_under_test.stop()
            assert tasks[0].cancelled()
            assert tasks[1].cancelled()
            assert isinstance(return_values[0], asyncio.CancelledError)
            assert isinstance(return_values[1], asyncio.CancelledError)
    
        @pytest.mark.asyncio
        async def test_producer(self):
            pass
    
        @pytest.mark.asyncio
        async def test_consumer(self):
            pass