pythonpytestfastapi

Pytest : Overriding production database with test database


I have started writing tests for my FastAPI/SQLAlchemy app and I would like to use a separate empty database for tests.

I added an override in my conftest.py file but the function override_get_db() never gets called. As a result, tests are run on the production database and cannot get them to run on the testing database. Any idea of what is wrong in my code ?

main.py


from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes.address import router as address_router


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(address_router)

database.py

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from models.base import Base
from config import Config
from sqlalchemy.orm import Session


engine = create_engine(
    Config.DATABASE_URI,
    echo=True,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def get_db():
    print(f"Connecting to database: {Config.DATABASE_URI}")
    Base.metadata.create_all(engine)
    db: Session = SessionLocal()
    try:
        yield db
    finally:
        db.close()

routes/address.py

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from crud.address import get, get_all, create, update, delete
from database.database import get_db
from schemas.address import AddressCreate

router = APIRouter()


@router.get("/address/{address_id}")
async def get_address(address_id: int, db: Session = Depends(get_db)):
    return get(db, address_id)


@router.get("/address/")
async def get_all_addresss(db: Session = Depends(get_db)):
    return get_all(db)


@router.post("/address/")
async def create_address(address: AddressCreate, db: Session = Depends(get_db)):
    return create(db, address)


@router.put("/address/{address_id}")
async def update_address(
    address_id: int, address: AddressCreate, db: Session = Depends(get_db)
):
    return update(db, address_id, address)


@router.delete("/address/{address_id}")
async def delete_address(address_id: int, db: Session = Depends(get_db)):
    return delete(db, address_id)

conftest.py

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import Engine, StaticPool, create_engine
from sqlalchemy.orm import sessionmaker
from main import app
from config import Config
from src.database.database import get_db
from src.models.base import Base

print("Loading conftest.py")

TEST_DATABASE_URI = "sqlite:///:memory:"


@pytest.fixture(scope="session")
def engine() -> Engine:
    print(f"Using database URI: {Config.TEST_DATABASE_URI}")
    return create_engine(
        Config.TEST_DATABASE_URI,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
        echo=True,
    )


@pytest.fixture(scope="function")
def test_db(engine):
    Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
        Base.metadata.drop_all(bind=engine)


@pytest.fixture(scope="function")
def override_get_db():
    def _override_get_db():
        print("Using test database")
        try:
            yield test_db
        finally:
            test_db.close()

    return _override_get_db


@pytest.fixture(scope="function")
def test_app(override_get_db):
    print("Applying dependency override")
    app.dependency_overrides[get_db] = override_get_db
    yield app
    print("Clearing dependency override")
    app.dependency_overrides.clear()


@pytest.fixture(scope="function")
def client(test_app):
    return TestClient(test_app)

test_address.py

def test_create_address(client):
    response = client.post(
        "/address/",
        json={
            "city": "Springfield",
            "country": "USA",
        },
    )
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["city"] == "Springfield"
    assert response_data["country"] == "USA"
    assert "id" in response_data


Solution

  • There is a problem with the way you're trying to override get_db:

    app.dependency_overrides[get_db] = override_get_db
    

    override_get_db is a fixture. You can't use fixtures for dependencies.

    There are many possible solutions. From what I see, currently you only need the database for your client, so you could add all logic in the client fixture:

    import pytest
    from fastapi.testclient import TestClient
    from sqlalchemy import create_engine,StaticPool
    from sqlalchemy.orm import sessionmaker
    from main import app
    from config import Config
    from models.base import Base
    from database.database import get_db
    
    @pytest.fixture(scope="function")
    def client(): 
        engine = create_engine(
            Config.TEST_DATABASE_URI,
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
            echo=True
        )  
        Base.metadata.create_all(bind=engine)
        TestingSessionLocal= sessionmaker(autocommit=False,autoflush=False, bind=engine)
        
        def override_get_db():
            db =TestingSessionLocal()
            try:
                yield db
            finally:
                db.close()
        
        app.dependency_overrides[get_db]= override_get_db
        with TestClient(app) as client:
            yield client
    
        Base.metadata.drop_all(bind=engine )
        app.dependency_overrides.clear()
    

    Or you could separate the logic instead.

    Main point – don't use fixtures for overriding dependencies. Use regular functions instead.