python async-await sqlalchemy pytest python-asyncio

Correct usage of async session on pytest


Given the implementation below, I am trying to test it using an asynchronous session. My attempt is as follows:

models.py

from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection

class Paginator:
    """Hold a SQL query together with the state used to page through it.

    The constructor only stores its arguments and initialises the paging
    state: ``current_offset`` starts at 0 and ``total_count`` starts as
    ``None``.
    """

    def __init__(
        self,
        conn: Union[Connection, AsyncConnection],
        query: str,
        params: dict = None,
        batch_size: int = 10
    ):
        """Store the connection, the query text, its parameters and batch size.

        Args:
            conn: Object the query is executed on.
            query: SQL text of the query to paginate.
            params: Bind parameters for ``query``; ``None`` means no parameters.
            batch_size: Stored as ``self.batch_size``.
        """
        self.conn = conn
        self.query = query
        self.params = params
        self.batch_size = batch_size
        # Paging state: begin at the first row, total not fetched yet.
        self.current_offset = 0
        self.total_count = None

    async def _get_total_count_async(self) -> int:
        """Return the number of rows produced by ``self.query``.

        The query is wrapped as a subquery of ``SELECT COUNT(*)`` and awaited
        on ``self.conn`` with ``self.params`` bound (nothing is bound when
        ``self.params`` is ``None`` or empty).
        """
        bind_values = self.params or {}
        statement = text(
            f"SELECT COUNT(*) FROM ({self.query}) as total"
        ).bindparams(**bind_values)
        outcome = await self.conn.execute(statement)
        return outcome.scalar()

test_models.py

@pytest.fixture(scope='function')
async def async_session():
    """Yield an ``AsyncSession`` with a transaction begun, then roll it back."""
    engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
    session_factory = sessionmaker(
        bind=engine,
        class_=AsyncSession,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )

    async with session_factory() as session:
        # Open a transaction before handing the session over ...
        await session.begin()

        yield session

        # ... and discard whatever was done with it afterwards.
        await session.rollback()

@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Check that the paginator's total-count query returns 0 for test_table."""
    # Prepare the paginator. The fixture value arrives as the
    # ``async_session`` argument; no other session name is in scope here.
    paginator = Paginator(
        conn=async_session,
        query="SELECT * FROM test_table",
        batch_size=2
    )

    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()

    # Assertion to verify the result
    assert total_count == 0

When I run pytest, I get the following error: AttributeError: 'async_generator' object has no attribute 'execute'. I am pretty sure there is an easy way to do this, but I am unaware of it.


Solution

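  • You should pass the object yielded by the fixture (the session, which provides execute) to the Paginator class, but you're passing the un-driven async generator that the fixture function returns, which is why the error mentions 'async_generator'.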

    To solve the issue there are two possible approaches:

    1. Resolve the session to reach the AsyncConnection within the test function:
    @pytest.mark.asyncio
    async def test_get_total_count_async(async_session):
        """Iterate the fixture's async generator by hand to get what it yields."""
        # Each item produced by ``async_session`` is handed to the paginator
        # as its connection.
        async for yielded in async_session:
            paginator = Paginator(
                conn=yielded,
                query="SELECT * FROM test_table",
                batch_size=2,
            )
            ...
    
    2. Using the pytest_asyncio PyPI package for the fixture:
    @pytest_asyncio.fixture
    async def async_session():
        """Yield an ``AsyncSession`` inside a transaction that is rolled back.

        A new engine is created for every test. It is disposed once the test
        is done so that its connection pool is not left open.
        """
        async_engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
        # Named ``session_factory`` so it does not shadow the fixture itself.
        session_factory = sessionmaker(
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
            bind=async_engine,
            class_=AsyncSession,
        )

        try:
            async with session_factory() as session:
                await session.begin()
                yield session
                # Undo everything the test did so tests stay isolated.
                await session.rollback()
        finally:
            # Release the pooled connections held by this per-test engine.
            await async_engine.dispose()
    
    @pytest.mark.asyncio
    async def test_get_total_count_async(async_session):
        """Check that the paginator's total-count query returns 0 for test_table."""
        # Prepare the paginator. The session yielded by the fixture arrives
        # as the ``async_session`` argument; no name ``session`` exists here.
        paginator = Paginator(
            conn=async_session,
            query="SELECT * FROM test_table",
            batch_size=2
        )

        # Perform the total count query asynchronously
        total_count = await paginator._get_total_count_async()

        # Assertion to verify the result
        assert total_count == 0
    

    Here's a post regarding this issue.