Tags: python, mocking, python-asyncio, fastapi, pytest-asyncio

FastAPI: Not able to use AsyncMock in "async with connection.execute()" context for aiosqlite (async sqlite3 library) while testing the endpoint


I am using Python 3.13.2

I have an `async with connection.execute(query) as cursor:` block in my endpoint, which I want to mock (`connection` is an object generated by `await aiosqlite.connect(':memory:')`).

This is the minimal code I have (endpoint, non-mock test, and mock-test). The non-mock test is passing, but the mock-test is giving me an error.

from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
import aiosqlite

# Application object that the endpoint below is registered on.
app = FastAPI()

# Test client bound to the app; the tests send their requests through it.
client = TestClient(app)

@app.post("/create-users")
async def add_users(users: list[dict[str, int | str]]) -> dict[str, int]:
    """Insert the given users into a fresh in-memory database and report the largest id.

    Args:
        users: Records carrying a ``user_id`` and a ``first_name`` key.

    Returns:
        ``{'max_user_id': n}`` where ``n`` is ``MAX(user_id)`` over the inserted
        rows, or ``0`` when there is nothing to take a maximum of.
    """
    conn = await aiosqlite.connect(':memory:')
    try:
        conn.row_factory = aiosqlite.Row
        _ = await conn.execute('CREATE TABLE users (user_id INTEGER, first_name TEXT)')
        for item in users:
            _ = await conn.execute(
                "INSERT INTO users VALUES(?, ?)",
                (item['user_id'], item['first_name']),
            )
        max_user_id = 0
        # Leaving the ``async with`` block closes the cursor, so no explicit
        # ``cursor.close()`` is needed afterwards.
        async with conn.execute('SELECT MAX(user_id) FROM users') as cursor:
            row = await cursor.fetchone()
            # MAX() over an empty table yields a row holding NULL; keep the
            # default of 0 so the declared ``dict[str, int]`` return still holds.
            if row is not None and row[0] is not None:
                max_user_id = row[0]
    finally:
        # Release the connection even when one of the statements above raises.
        await conn.close()
    return {'max_user_id': max_user_id} # Getting {'max_user_id': 2}

def test_create_users_no_mock():
    """Call /create-users without any patching and check the reported maximum id."""
    users = [
        {"user_id": uid, "first_name": name}
        for uid, name in ((1, "Alice"), (2, "Bob"))
    ]
    reply = client.post('/create-users', json=users)
    assert reply.status_code == 200
    assert reply.json() == {'max_user_id': 2}

@pytest.mark.asyncio
async def test_create_users_mock():
    """First attempt at testing /create-users with ``aiosqlite.connect`` patched.

    As posted, this test fails: the endpoint's ``async with conn.execute(...)``
    raises ``TypeError: 'coroutine' object does not support the asynchronous
    context manager protocol``.
    """
    payload = [
        {"user_id": 1, "first_name": "Alice"},
        {"user_id": 2, "first_name": "Bob"},
    ]
    expected: dict[str, int] = {"max_user_id": 2}

    cursor = AsyncMock()
    # NOTE(review): fetchone is already an AsyncMock attribute, so awaiting it
    # presumably yields this inner AsyncMock rather than (2,) — confirm.
    cursor.fetchone.return_value = AsyncMock(return_value=(2,))

    # Minimal async context manager whose __aenter__ hands back a preset value.
    class AsyncContextMock:
        def __init__(self, return_value=None):
            self.return_value = return_value

        async def __aenter__(self):
            return self.return_value

        async def __aexit__(self, exc_type, exc, tb):
            # Returning False means exceptions raised in the block are not suppressed.
            return False

    # Coroutine function: calling it gives a coroutine that resolves to the
    # context manager, not the context manager itself.
    async def fake_execute(*args, **kwargs):
        return AsyncContextMock(return_value=cursor)

    # conn.execute is an attribute of an AsyncMock. The reported RuntimeWarning
    # names AsyncMockMixin._execute_mock_call: calling conn.execute(...) produces
    # a coroutine, and that coroutine is what ``async with`` rejects.
    conn = AsyncMock()
    conn.execute.side_effect = fake_execute

    # Stand-in for aiosqlite.connect; awaiting its result gives the mocked connection.
    async def fake_connect(*args, **kwargs):
        return conn

    with patch('aiosqlite.connect', side_effect=fake_connect):
        response = client.post('/create-users', json=payload)
    data = response.json()
    assert response.status_code == 200
    assert data == expected

Upon execution, I am getting an error

TypeError: 'coroutine' object does not support the asynchronous context manager protocol

Here is the execution trace.

% pytest test_main2.py
============================================================================================ test session starts =============================================================================================
platform darwin -- Python 3.13.2, pytest-9.0.1, pluggy-1.6.0                                           
rootdir: /Users/amit_tendulkar/quest/experiment                                                        
plugins: mock-3.15.1, langsmith-0.4.11, anyio-4.10.0, asyncio-1.3.0        
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 2 items                                                                                                                                                                                            
                                                                                                       
test_main2.py .F                                                                                                                                                                                       [100%]
                                                   
================================================================================================== FAILURES ==================================================================================================
___________________________________________________________________________________________ test_create_users_mock ___________________________________________________________________________________________
                                                                                                       
    @pytest.mark.asyncio                                                                               
    async def test_create_users_mock():         
        payload = [                                                                                                                                                                                           
            {"user_id": 1, "first_name": "Alice"},                                                     
            {"user_id": 2, "first_name": "Bob"},                                                       
        ]                                                                                              
        expected: dict[str, int] = {"max_user_id": 2}
                                                                                                       
        cursor = AsyncMock()                                                                           
        cursor.fetchone.return_value = AsyncMock(return_value=(2,)) 
                                                                                                                                                                                                              
        class AsyncContextMock:
            def __init__(self, return_value=None):
                self.return_value = return_value                                                                                                                                                              
                                                                                                       
            async def __aenter__(self):                                                                
                return self.return_value
                                                                                                       
            async def __aexit__(self, exc_type, exc, tb):                              
                return False                                                                           
                                                                                                       
        async def fake_execute(*args, **kwargs):                                                       
            return AsyncContextMock(return_value=cursor)                                                                                                                                                      
                                     
        conn = AsyncMock()                                                                                                                                                                                    
        conn.execute.side_effect = fake_execute                                                        
                                                                                                       
        async def fake_connect(*args, **kwargs):                                                       
            return conn                                                                                
                                                                                                       
        with patch('aiosqlite.connect', side_effect=fake_connect):                                                                                                                                            
>           response = client.post('/create-users', json=payload)                                                                                                                                             
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                             
                                                   
.
.
.
output trimmed
.
.
.
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

users = [{'first_name': 'Alice', 'user_id': 1}, {'first_name': 'Bob', 'user_id': 2}]

    @app.post("/create-users")
    async def add_users(users: list[dict[str, int | str]]) -> dict[str, int]:
        conn = await aiosqlite.connect(':memory:')
        conn.row_factory = aiosqlite.Row
        _ = await conn.execute('CREATE TABLE users (user_id INTEGER, first_name TEXT)')
        for item in users:
            _ = await conn.execute(
                "INSERT INTO users VALUES(?, ?)",
                (item['user_id'], item['first_name']),
            )
        max_user_id = 0
>       async with conn.execute('SELECT MAX(user_id) FROM users') as cursor:
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: 'coroutine' object does not support the asynchronous context manager protocol

test_main2.py:22: TypeError
============================================================================================== warnings summary ==============================================================================================
test_main2.py::test_create_users_mock
  /Users/amit_tendulkar/quest/experiment/test_main2.py:22: RuntimeWarning: coroutine 'AsyncMockMixin._execute_mock_call' was never awaited
    async with conn.execute('SELECT MAX(user_id) FROM users') as cursor:
  Enable tracemalloc to get traceback where the object was allocated.
  See https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings for more info.

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================== short test summary info ===========================================================================================
FAILED test_main2.py::test_create_users_mock - TypeError: 'coroutine' object does not support the asynchronous context manager protocol
=================================================================================== 1 failed, 1 passed, 1 warning in 0.24s ===================================================================================

What am I doing wrong?

EDIT based on @dmitri-galkin's comment

Here is the updated function based on suggestions from @dmitri-galkin.

However, I am still getting the exact same error.

Updated function,

@pytest.mark.asyncio
async def test_create_users_mock():
    """Second attempt at testing /create-users with ``aiosqlite.connect`` patched.

    Compared with the first attempt, ``fetchone`` returns the tuple directly,
    the context manager gains ``__await__`` and ``fake_execute`` is a plain
    function. As posted, it still fails with the same ``TypeError`` from
    ``async with conn.execute(...)``.
    """
    payload = [
        {"user_id": 1, "first_name": "Alice"},
        {"user_id": 2, "first_name": "Bob"},
    ]
    expected: dict[str, int] = {"max_user_id": 2}

    cursor = AsyncMock()
    cursor.fetchone.return_value = (2,)

    # Async context manager whose __aenter__ hands back a preset value.
    class AsyncContextMock:
        def __init__(self, return_value=None):
            self.return_value = return_value

        async def __aenter__(self):
            return self.return_value

        async def __aexit__(self, exc_type, exc, tb):
            # Returning False means exceptions raised in the block are not suppressed.
            return False

        # Added so the object can also be awaited directly.
        # NOTE(review): this yields the value instead of returning it — confirm
        # that it behaves as intended when actually awaited.
        def __await__(self):
            yield self.return_value

        # Now a plain function, but it is still installed below as the side_effect
        # of an AsyncMock attribute, and the reported error is unchanged.
    def fake_execute(*args, **kwargs):
        return AsyncContextMock(return_value=cursor)

    # conn is still an AsyncMock; the reported RuntimeWarning again names
    # AsyncMockMixin._execute_mock_call, i.e. conn.execute(...) produced a coroutine.
    conn = AsyncMock()
    conn.execute.side_effect = fake_execute

    # Stand-in for aiosqlite.connect; awaiting its result gives the mocked connection.
    async def fake_connect(*args, **kwargs):
        return conn

    with patch('aiosqlite.connect', side_effect=fake_connect):
        response = client.post('/create-users', json=payload)
    data = response.json()
    assert response.status_code == 200
    assert data == expected

Execution got the error,

                                                                                                       
users = [{'first_name': 'Alice', 'user_id': 1}, {'first_name': 'Bob', 'user_id': 2}]                   
                                                   
    @app.post("/create-users")                                                                                                                                                                                
    async def add_users(users: list[dict[str, int | str]]) -> dict[str, int]:                          
        conn = await aiosqlite.connect(':memory:')                                                                                                                                                            
        conn.row_factory = aiosqlite.Row                                                                                                                                                                      
        _ = await conn.execute('CREATE TABLE users (user_id INTEGER, first_name TEXT)')                
        for item in users:                                                                             
            _ = await conn.execute(                                                                                                                                                                           
                "INSERT INTO users VALUES(?, ?)",                                                      
                (item['user_id'], item['first_name']),                                                                                                                                                        
            )                                                                                                                                                                                                 
        max_user_id = 0                                                                                                                                                                                       
>       async with conn.execute('SELECT MAX(user_id) FROM users') as cursor:                           
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                      
E       TypeError: 'coroutine' object does not support the asynchronous context manager protocol       
                                                                                                       
test_main2.py:26: TypeError                                                                                                                                                                                   
============================================================================================== warnings summary ==============================================================================================
test_main2.py::test_create_users_mock                                                                                                                                                                         
  /Users/amit_tendulkar/quest/experiment/test_main2.py:26: RuntimeWarning: coroutine 'AsyncMockMixin._execute_mock_call' was never awaited
    async with conn.execute('SELECT MAX(user_id) FROM users') as cursor:                               
  Enable tracemalloc to get traceback where the object was allocated.                                  
  See https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings for more info.  
                                                                                                       
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html                                                                                                                                       
========================================================================================== short test summary info ===========================================================================================
FAILED test_main2.py::test_create_users_mock - TypeError: 'coroutine' object does not support the asynchronous context manager protocol                                                                       
=================================================================================== 1 failed, 1 passed, 1 warning in 0.26s ===================================================================================


Solution

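  • The AsyncContextMock's __aenter__ method was never called during `async with`: because conn was an AsyncMock, conn.execute(...) returned a coroutine instead of the context-manager object (hence the TypeError). So I configured __aenter__ separately, on a MagicMock-based object returned by a non-async conn.execute.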

    The following test case is now passing.

    @pytest.mark.asyncio
    async def test_create_users_mock():
        """Test /create-users with ``aiosqlite.connect`` patched to return a mocked connection.

        The endpoint uses the result of ``conn.execute(...)`` in two ways: it
        awaits it (CREATE/INSERT) and it enters it with ``async with`` (SELECT).
        The mocked ``execute`` result therefore has to support both protocols.
        """
        # MagicMock is not part of the module-level ``unittest.mock`` import
        # (only AsyncMock and patch are), so bring it in here.
        from unittest.mock import MagicMock

        payload = [
            {"user_id": 1, "first_name": "Alice"},
            {"user_id": 2, "first_name": "Bob"},
        ]
        expected: dict[str, int] = {"max_user_id": 2}

        class AsyncContextMock(MagicMock):
            """MagicMock that can additionally be awaited directly.

            ``__aenter__``/``__aexit__`` are configured through the mock as
            usual; this subclass only adds ``__await__``.
            """

            def __await__(self):
                async def _coro():
                    # The endpoint discards the awaited value, so any result will do.
                    return 1
                return _coro().__await__()

        # conn is a plain MagicMock so that conn.execute(...) returns the object
        # below immediately instead of wrapping it in a coroutine.
        conn = MagicMock(name='conn')
        execute = AsyncContextMock(name='execute')
        cursor = MagicMock(name='cursor')

        # Only the methods the endpoint awaits are made asynchronous.
        cursor.fetchone = AsyncMock(return_value=(2,))
        cursor.close = AsyncMock()
        conn.close = AsyncMock()

        # ``async with conn.execute(...) as cursor`` binds the mocked cursor.
        execute.__aenter__.return_value = cursor
        conn.execute.return_value = execute

        # The endpoint awaits aiosqlite.connect(...), so the replacement must be awaitable.
        with patch('aiosqlite.connect', new=AsyncMock(return_value=conn)):
            response = client.post('/create-users', json=payload)
        assert response.status_code == 200
        assert response.json() == expected