I'm trying out the new SQLAlchemy 1.4 (beta) and ran into difficulty porting the "Joining a Session into an External Transaction (such as for test suites)" recipe to the asyncio API with pytest.
First, I converted zzzeek's unittest example to pytest, which works fine:
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
# declarative base shared by every mapped class in this test module
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: the ``thing`` table with an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer(), primary_key=True)
@pytest.fixture(scope="session")
def engine_fixture():
    """Yield one Engine for the whole test session, with a fresh schema.

    Drops and recreates every table on ``Base.metadata`` before the tests run.
    After the tests finish, it drops the tables again and disposes of the
    engine's connection pool.
    """
    engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
    # start from a clean schema even if a previous run aborted mid-way
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)
    yield engine
    Base.metadata.drop_all(engine)
    # release pooled connections instead of leaving them to interpreter exit
    engine.dispose()
@pytest.fixture
def session(engine_fixture):
    """Yield a Session whose work is discarded when the test ends.

    The Session is bound to one connection that sits inside an outer
    transaction. Three ``Thing`` rows are seeded, and then the Session is
    placed in a SAVEPOINT. Whenever that top-level SAVEPOINT ends, it is
    reopened. Teardown rolls back the outer transaction.
    """
    connection = engine_fixture.connect()
    outer_transaction = connection.begin()
    db_session = Session(bind=connection)

    # seed fixture rows inside the outer (never committed) transaction
    db_session.add_all([Thing() for _ in range(3)])
    db_session.commit()

    # enter a SAVEPOINT so tests may commit/rollback freely
    db_session.begin_nested()

    @event.listens_for(db_session, "after_transaction_end")
    def restart_savepoint(sess, transaction):
        # only the outermost SAVEPOINT ending should trigger a new one
        ended_top_savepoint = transaction.nested and not transaction._parent.nested
        if ended_top_savepoint:
            sess.begin_nested()

    yield db_session

    # teardown mirrors the documented recipe
    db_session.close()
    outer_transaction.rollback()
    connection.close()
def _test_thing(session, extra_rollback=0):
    """Check row counts through a sequence of add/commit/rollback steps.

    Expects ``session`` to start with the three ``Thing`` rows seeded by the
    ``session`` fixture. First it runs ``extra_rollback`` add-3-then-rollback
    cycles. Then it commits two rows and adds one more without committing.
    Finally it runs ``extra_rollback`` more add-3-then-rollback cycles. The
    row count is asserted after every step.

    :param session: the Session yielded by the ``session`` fixture.
    :param extra_rollback: number of add/rollback cycles in each of the two
        loops; 0 skips both loops.
    """
    rows = session.query(Thing).all()
    assert len(rows) == 3
    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        # the query sees the pending rows: 3 seeded + 3 new
        rows = session.query(Thing).all()
        assert len(rows) == 6
        session.rollback()
    # after rollbacks, still @ 3 rows
    rows = session.query(Thing).all()
    assert len(rows) == 3
    session.add_all([Thing(), Thing()])
    session.commit()
    rows = session.query(Thing).all()
    assert len(rows) == 5
    # this sixth row is never committed, so a later rollback discards it
    session.add(Thing())
    rows = session.query(Thing).all()
    assert len(rows) == 6
    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        session.rollback()
    rows = session.query(Thing).all()
    # any rollback also discarded the uncommitted sixth row
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6
def test_thing_one_pytest(session):
    """Exercise the savepoint fixture without any extra rollback cycles."""
    _test_thing(session, extra_rollback=0)
def test_thing_two_pytest(session):
    """Exercise the savepoint fixture with two extra rollback cycles."""
    _test_thing(session, extra_rollback=2)
Then I tried switching to the asyncio API, using pytest-asyncio version 0.14.0:
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
# declarative base shared by every mapped class in this test module
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: the ``thing`` table with an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer(), primary_key=True)
@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    """Create the schema once per test session using a synchronous engine.

    DDL is issued through a plain (psycopg2-style URL) engine. Tables are
    dropped and recreated before the tests and dropped again afterwards, and
    then the engine's pool is disposed. Yields the sync engine.
    """
    # setup
    sync_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    Base.metadata.drop_all(sync_engine)
    Base.metadata.create_all(sync_engine)
    yield sync_engine
    # teardown
    Base.metadata.drop_all(sync_engine)
    # release pooled connections instead of leaving them to interpreter exit
    sync_engine.dispose()
@pytest.fixture(scope="session")
async def async_engine():
    """Yield one asyncpg-backed AsyncEngine shared by the whole test session.

    This is an async generator fixture, so it yields an ``AsyncEngine`` rather
    than returning one. The previous ``-> AsyncEngine`` annotation was
    therefore incorrect and has been removed. The engine is disposed at
    teardown so that its pooled asyncpg connections are closed while the
    event loop is still running.

    NOTE(review): with pytest-asyncio 0.14, a session-scoped async fixture
    needs a session-scoped ``event_loop`` fixture (not shown here). Confirm
    that one is defined, e.g. in conftest.py.
    """
    # setup
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
    )
    yield engine
    # teardown
    await engine.dispose()
@pytest.fixture(scope="function")
async def session(async_engine):
    """Yield an AsyncSession whose changes are discarded after each test.

    The session is bound to a single connection inside an outer transaction,
    which is rolled back at teardown. Three ``Thing`` rows are loaded first.
    Then a SAVEPOINT is opened on the connection and reopened every time the
    session ends it. As a result, ``session.commit()`` and
    ``session.rollback()`` inside a test only affect that SAVEPOINT and never
    the outer transaction.

    :param async_engine: the session-scoped AsyncEngine fixture.
    """
    conn = await async_engine.connect()
    trans = await conn.begin()
    session = AsyncSession(bind=conn)

    async def _fixture(session: AsyncSession):
        session.add_all([Thing(), Thing(), Thing()])
        await session.commit()

    # Load fixture data within the scope of the transaction. The session did
    # not begin ``trans`` (the connection was already in it), so this commit
    # is expected to leave the outer transaction open.
    await _fixture(session)

    # Start the SAVEPOINT on the *connection*. The "future"-style engine used
    # by asyncio does not support the legacy session.begin_nested() pattern
    # for restarting savepoints.
    await conn.begin_nested()

    # Each time the session ends the SAVEPOINT (commit or rollback), reopen it.
    # NOTE: there are no async event listeners, so hook the sync session and
    # drive the underlying sync Connection directly.
    @event.listens_for(session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        sync_conn = conn.sync_connection
        if sync_conn.closed:
            return
        if not sync_conn.in_nested_transaction():
            sync_conn.begin_nested()

    yield session

    # Same teardown as the docs: rolling back ``trans`` discards the fixture
    # rows and anything the test committed into the SAVEPOINT.
    await session.close()
    await trans.rollback()
    await conn.close()
async def _test_thing(session: AsyncSession, extra_rollback: int = 0):
    """Async port of the row-count check through add/commit/rollback steps.

    Expects ``session`` to start with the three ``Thing`` rows seeded by the
    ``session`` fixture. First it runs ``extra_rollback`` add-3-then-rollback
    cycles. Then it commits two rows and adds one more without committing.
    Finally it runs ``extra_rollback`` more add-3-then-rollback cycles. The
    row count (the ``len`` of the rows returned by ``execute(select(Thing))``)
    is asserted after every step.

    :param session: the AsyncSession yielded by the ``session`` fixture.
    :param extra_rollback: number of add/rollback cycles in each of the two
        loops; 0 skips both loops.
    """
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3
    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        # the query sees the pending rows: 3 seeded + 3 new
        rows = (await session.execute(select(Thing))).all()
        assert len(rows) == 6
        await session.rollback()
    # after rollbacks, still @ 3 rows
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3
    session.add_all([Thing(), Thing()])
    await session.commit()
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 5
    # this sixth row is never committed, so a later rollback discards it
    session.add(Thing())
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 6
    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        await session.rollback()
    rows = (await session.execute(select(Thing))).all()
    # any rollback also discarded the uncommitted sixth row
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6
@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    """Exercise the async savepoint fixture without extra rollback cycles."""
    await _test_thing(session, extra_rollback=0)
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    """Exercise the async savepoint fixture with two extra rollback cycles."""
    await _test_thing(session, extra_rollback=2)
This, however, fails with `FAILED test_thing_two_pytest - assert 8 == 3`: the transaction rollback in the teardown after the first test does not restore the database to the SAVEPOINT created during setup, so rows from the first test leak into the second.
My knowledge of SQLAlchemy internals is limited, so I'm asking for help setting this up; it is crucial for my test suite's performance.
Could it be that, because async event listeners are missing, defining `restart_savepoint` in terms of `AsyncSession.sync_session` is not sufficient, and I simply have to wait for a stable 1.4 release?
Thanks!
It turned out to be a bug; I reached out to the SQLAlchemy developers directly.
Note: there is an API change. Use `connection.begin_nested()` instead of `session.begin_nested()` to create and restart the SAVEPOINT. According to @zzzeek:
> The "legacy" pattern that you have above which uses "session.begin_nested()" to create the savepoint, this is not supported for the "future" style engine which asyncio uses. The new version uses the connection itself to recreate the savepoint inside the event.