Tags: python, sqlalchemy, fastapi, httpx, pytest-asyncio

How to test a FastAPI route that retries a SQLAlchemy insert after a rollback?


I have a route that retries an insert when it fails with an IntegrityError. I’m testing it with pytest and httpx, but I get an error when the session is reused to retry the insert after the failed one has been rolled back. The route works fine when I call it with curl.
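The retry loop itself can be checked in isolation from the test setup. Below is a minimal sketch using only the standard library's sqlite3 module instead of SQLAlchemy and FastAPI; the helper add_first_free_name and the table layout are hypothetical stand-ins for the route and the Animal model. Each attempt that hits the unique constraint is rolled back and the next name is tried, which works because the rollback only ends the transaction that the failed attempt started.

```python
import sqlite3
from typing import Optional

CANDIDATE_NAMES = ("Max", "Cody", "Robby")


def add_first_free_name(conn: sqlite3.Connection) -> Optional[str]:
    """Insert the first candidate name not taken yet (hypothetical helper)."""
    for name in CANDIDATE_NAMES:
        try:
            conn.execute("INSERT INTO animals (name) VALUES (?)", (name,))
        except sqlite3.IntegrityError:
            # UNIQUE constraint hit: undo this attempt, then try the next name.
            conn.rollback()
            continue
        conn.commit()
        return name
    return None  # every candidate is already taken


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE animals (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE)"
)

# Four "requests" in a row, all on the same connection.
results = [add_first_free_name(conn) for _ in range(4)]
print(results)  # ['Max', 'Cody', 'Robby', None]
```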

I use Python 3.10 with FastAPI 0.95 and SQLAlchemy 2.0, the latest versions at the time of writing. My test setup is based on a blog post (linked at the top of tests.py below); it works well for other tests, but not for this one.

Here is a minimal reproducible example (I left out the imports to reduce the code):

database.py:

async_engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_maker = async_sessionmaker(bind=async_engine, class_=AsyncSession, expire_on_commit=False)

async def get_async_db_session():
    async with async_session_maker() as session:
        yield session

class Base(DeclarativeBase):
    pass

class Animal(Base):
    __tablename__ = "animals"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False, unique=True)

main.py:

app = FastAPI()

@app.post("/add")
async def root(session=Depends(get_async_db_session)):
    for name in ("Max", "Cody", "Robby"):
        session.add(Animal(name=name))
        try:
            await session.flush()
        except IntegrityError:
            await session.rollback()
            continue  # retry

        await session.commit()
        return name

    return None

tests.py:

# test setup based on https://dev.to/jbrocher/fastapi-testing-a-database-5ao5
@pytest.fixture(scope="session")
def event_loop():
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()

@pytest.fixture(scope="session")
async def db_engine():
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine

@pytest.fixture(scope="function")
async def db(db_engine):
    async with db_engine.connect() as connection:
        async with connection.begin() as transaction:
            db_session = AsyncSession(bind=connection)
            yield db_session

            await transaction.rollback()

@pytest.fixture(scope="function")
async def client(db):
    app.dependency_overrides[get_async_db_session] = lambda: db
    async with AsyncClient(app=app, base_url="http://test") as c:
        yield c

async def test_add(client):
    r = await client.post("/add")
    assert r.json() == "Max"

    r = await client.post("/add")
    assert r.json() == "Cody"

I run the tests with pytest --asyncio-mode=auto tests.py.

The test simulates two requests to the endpoint. The first one succeeds, but the second one fails with the following error:

Can't operate on closed transaction inside context manager. Please complete the context manager before emitting further commands.

The traceback points to the line with await session.flush() in main.py.

I don’t understand what I need to change in the test setup (or in the route?) to make this work.


Solution

  • The issue seems to lie in the dependency override: the override (lambda: db) returns the same DB session for every call to the dependency, so all requests within a test share one session instead of getting a fresh one each time. I tried nested transactions, to no avail.

    In the end, I changed the pytest fixtures to create a new session for each request, and instead of rolling back at the end of each test, I discard the in-memory database and re-create it for every test.

    Relevant part:

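    import pytest
    from httpx import ASGITransport, AsyncClient
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from database import Base, get_async_db_session
    from main import app


    @pytest.fixture()
    async def db_engine():
        # Function scope: every test gets a brand-new in-memory database,
        # so no rollback is needed to isolate tests from each other.
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
        await engine.dispose()


    @pytest.fixture()
    async def client(db_engine):
        # FastAPI calls this override once per request, so each request gets
        # its own connection and session, like the real get_async_db_session.
        async def get_async_db_session_test():
            async with db_engine.connect() as connection:
                async with AsyncSession(bind=connection) as db_session:
                    yield db_session

        app.dependency_overrides[get_async_db_session] = get_async_db_session_test

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
            yield c

        app.dependency_overrides.clear()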