I have started writing tests for my FastAPI/SQLAlchemy app and I would like to use a separate empty database for tests.
I added an override in my conftest.py file but the function override_get_db() never gets called. As a result, tests are run on the production database and cannot get them to run on the testing database. Any idea of what is wrong in my code ?
main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes.address import router as address_router
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(address_router)
database.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from models.base import Base
from config import Config
from sqlalchemy.orm import Session
engine = create_engine(
Config.DATABASE_URI,
echo=True,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
print(f"Connecting to database: {Config.DATABASE_URI}")
Base.metadata.create_all(engine)
db: Session = SessionLocal()
try:
yield db
finally:
db.close()
routes/address.py
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from crud.address import get, get_all, create, update, delete
from database.database import get_db
from schemas.address import AddressCreate
router = APIRouter()
@router.get("/address/{address_id}")
async def get_address(address_id: int, db: Session = Depends(get_db)):
return get(db, address_id)
@router.get("/address/")
async def get_all_addresss(db: Session = Depends(get_db)):
return get_all(db)
@router.post("/address/")
async def create_address(address: AddressCreate, db: Session = Depends(get_db)):
return create(db, address)
@router.put("/address/{address_id}")
async def update_address(
address_id: int, address: AddressCreate, db: Session = Depends(get_db)
):
return update(db, address_id, address)
@router.delete("/address/{address_id}")
async def delete_address(address_id: int, db: Session = Depends(get_db)):
return delete(db, address_id)
conftest.py
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import Engine, StaticPool, create_engine
from sqlalchemy.orm import sessionmaker
from main import app
from config import Config
from src.database.database import get_db
from src.models.base import Base
print("Loading conftest.py")
TEST_DATABASE_URI = "sqlite:///:memory:"
@pytest.fixture(scope="session")
def engine() -> Engine:
print(f"Using database URI: {Config.TEST_DATABASE_URI}")
return create_engine(
Config.TEST_DATABASE_URI,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=True,
)
@pytest.fixture(scope="function")
def test_db(engine):
Base.metadata.create_all(bind=engine)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def override_get_db():
def _override_get_db():
print("Using test database")
try:
yield test_db
finally:
test_db.close()
return _override_get_db
@pytest.fixture(scope="function")
def test_app(override_get_db):
print("Applying dependency override")
app.dependency_overrides[get_db] = override_get_db
yield app
print("Clearing dependency override")
app.dependency_overrides.clear()
@pytest.fixture(scope="function")
def client(test_app):
return TestClient(test_app)
test_address.py
def test_create_address(client):
response = client.post(
"/address/",
json={
"city": "Springfield",
"country": "USA",
},
)
assert response.status_code == 200
response_data = response.json()
assert response_data["city"] == "Springfield"
assert response_data["country"] == "USA"
assert "id" in response_data
There is a problem with the way you're trying to override get_db
:
app.dependency_overrides[get_db] = override_get_db
override_get_db
is a fixture. You can't use fixtures for dependencies.
There are many possible solutions. From what I see, currently you only need the database for your client, so you could add all logic in the client
fixture:
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine,StaticPool
from sqlalchemy.orm import sessionmaker
from main import app
from config import Config
from models.base import Base
from database.database import get_db
@pytest.fixture(scope="function")
def client():
engine = create_engine(
Config.TEST_DATABASE_URI,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=True
)
Base.metadata.create_all(bind=engine)
TestingSessionLocal= sessionmaker(autocommit=False,autoflush=False, bind=engine)
def override_get_db():
db =TestingSessionLocal()
try:
yield db
finally:
db.close()
app.dependency_overrides[get_db]= override_get_db
with TestClient(app) as client:
yield client
Base.metadata.drop_all(bind=engine )
app.dependency_overrides.clear()
Or you could separate the logic instead.
Main point – don't use fixtures for overriding dependencies. Use regular functions instead.