Tags: python, sqlalchemy, python-asyncio, asyncpg, pytest-asyncio

Session in an External Transaction with an async engine


I'm trying out the new (beta) SQLAlchemy 1.4 and ran into difficulty porting the "Joining a Session into an External Transaction (such as for test suite)" recipe to the async API with pytest.

First, I converted zzzeek's unittest example to pytest, which works fine:

import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

# a model
class Thing(Base):
    """Minimal mapped model: table ``thing`` with a single integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session")
def engine_fixture():
    """Yield a synchronous engine for the test database.

    The schema described by ``Base.metadata`` is dropped and recreated before
    the engine is handed out, and dropped again after the ``yield``.
    """
    db_url = "postgresql://postgres:changethis@db/app_test"
    test_engine = create_engine(db_url, echo=True)
    schema = Base.metadata

    # drop first so that create_all() builds every table anew
    schema.drop_all(test_engine)
    schema.create_all(test_engine)

    yield test_engine

    schema.drop_all(test_engine)


@pytest.fixture
def session(engine_fixture):
    """Yield a Session joined to an external transaction on its own connection.

    Three ``Thing`` rows are committed through the session as fixture data,
    then a SAVEPOINT is opened and re-opened every time it ends. Teardown
    closes the session, rolls the outer transaction back and closes the
    connection.
    """
    connection = engine_fixture.connect()
    outer_transaction = connection.begin()
    db_session = Session(bind=connection)

    # load fixture data within the scope of the outer transaction
    db_session.add_all([Thing() for _ in range(3)])
    db_session.commit()

    # start the session in a SAVEPOINT...
    db_session.begin_nested()

    # ...then each time that SAVEPOINT ends, reopen it
    def restart_savepoint(session, transaction):
        if not transaction.nested or transaction._parent.nested:
            return
        session.begin_nested()

    event.listen(db_session, "after_transaction_end", restart_savepoint)

    yield db_session

    # same teardown from the docs
    db_session.close()
    outer_transaction.rollback()
    connection.close()


def _test_thing(session, extra_rollback=0):
    """Exercise commit/rollback on ``session`` and check the ``Thing`` row count.

    ``extra_rollback`` is the number of add-then-rollback rounds that are run
    both before and after the commit in the middle of the scenario.
    """

    def thing_count():
        return len(session.query(Thing).all())

    def add_things(how_many):
        session.add_all([Thing() for _ in range(how_many)])

    assert thing_count() == 3

    for _ in range(extra_rollback):
        # run N number of rollbacks
        add_things(3)
        assert thing_count() == 6

        session.rollback()

    # after rollbacks, still @ 3 rows
    assert thing_count() == 3

    add_things(2)
    session.commit()
    assert thing_count() == 5

    session.add(Thing())
    assert thing_count() == 6

    for round_no in range(extra_rollback):
        # run N number of rollbacks
        add_things(3)
        # after the first round the single extra "thing" added above
        # has been rolled back too, hence one row fewer
        expected = 9 if round_no == 0 else 8
        assert thing_count() == expected
        session.rollback()

    assert thing_count() == (5 if extra_rollback else 6)


def test_thing_one_pytest(session):
    """Run the shared scenario with zero extra rollback rounds."""
    _test_thing(session, extra_rollback=0)


def test_thing_two_pytest(session):
    """Run the shared scenario with two extra rollback rounds."""
    _test_thing(session, extra_rollback=2)

Then I tried switching to the asyncio API, using pytest-asyncio version 0.14.0:

import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

Base = declarative_base()

# a model
class Thing(Base):
    """Minimal mapped model: table ``thing`` with a single integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    """Create the schema through a synchronous engine and yield that engine.

    The tables of ``Base.metadata`` are dropped and recreated before the
    ``yield`` and dropped again after it.
    """
    db_url = "postgresql://postgres:changethis@db/app_test"
    sync_engine = create_engine(db_url, echo=True)
    schema = Base.metadata

    # setup: drop first so that create_all() builds every table anew
    schema.drop_all(sync_engine)
    schema.create_all(sync_engine)

    yield sync_engine

    # teardown
    schema.drop_all(sync_engine)


@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    """Yield an async engine (asyncpg driver) pointing at the test database."""
    db_url = "postgresql+asyncpg://postgres:changethis@db/app_test"
    yield create_async_engine(db_url, echo=True)


@pytest.fixture(scope="function")
async def session(async_engine):
    """Yield an AsyncSession bound to a connection with an open outer transaction.

    Three ``Thing`` rows are added and committed through the session, then a
    SAVEPOINT is started on the session. Teardown closes the session, rolls
    back the outer transaction and closes the connection.

    NOTE(review): this is the variant the post reports as failing with
    ``assert 8 == 3`` in ``test_thing_two_pytest``. Per the resolution quoted
    at the end of the post, creating the savepoint with
    ``session.begin_nested()`` is the "legacy" pattern and is not supported
    for the "future"-style engine that asyncio uses; the savepoint should be
    recreated on the connection (``connection.begin_nested()``) instead.
    """
    conn = await async_engine.connect()
    # outer transaction; rolled back in the teardown below
    trans = await conn.begin()
    session = AsyncSession(bind=conn)

    async def _fixture(session: AsyncSession):
        session.add_all([Thing(), Thing(), Thing()])
        await session.commit()

    # load fixture data within the scope of the transaction
    await _fixture(session)

    # start the session in a SAVEPOINT...
    await session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    # NOTE: no async listeners yet
    # The listener is attached to the underlying sync session, so it is a plain
    # function and calls begin_nested() without awaiting.
    @event.listens_for(session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    await session.close()
    await trans.rollback()
    await conn.close()


async def _test_thing(session: AsyncSession, extra_rollback=0):
    """Exercise commit/rollback on ``session`` and check the ``Thing`` row count.

    ``extra_rollback`` is the number of add-then-rollback rounds that are run
    both before and after the commit in the middle of the scenario.
    """

    async def thing_count():
        result = await session.execute(select(Thing))
        return len(result.all())

    def add_things(how_many):
        session.add_all([Thing() for _ in range(how_many)])

    assert await thing_count() == 3

    for _ in range(extra_rollback):
        # run N number of rollbacks
        add_things(3)
        assert await thing_count() == 6

        await session.rollback()

    # after rollbacks, still @ 3 rows
    assert await thing_count() == 3

    add_things(2)
    await session.commit()
    assert await thing_count() == 5

    session.add(Thing())
    assert await thing_count() == 6

    for round_no in range(extra_rollback):
        # run N number of rollbacks
        add_things(3)
        # after the first round the single extra "thing" added above
        # has been rolled back too, hence one row fewer
        expected = 9 if round_no == 0 else 8
        assert await thing_count() == expected
        await session.rollback()

    assert await thing_count() == (5 if extra_rollback else 6)


@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    """Run the shared async scenario with zero extra rollback rounds."""
    await _test_thing(session, extra_rollback=0)


@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    """Run the shared async scenario with two extra rollback rounds."""
    await _test_thing(session, extra_rollback=2)

This, however, fails with "FAILED test_thing_two_pytest - assert 8 == 3", because the transaction rollback in the teardown after the first test doesn't restore the database to the SAVEPOINT created in the setup phase.

As my knowledge of sqlalchemy internals is not that great, I'm seeking help in setting this up, as it is crucial for my test-suite performance.

Can it be that missing async event listeners and defining restart_savepoint in terms of AsyncSession.sync_session is not sufficient, and one simply has to wait for a stable release of 1.4 API?

Thanks!


Solution

  • It turned out to be a bug; I reached out to the SQLAlchemy developers directly.

    Github Issue

    Fix

    Note: there is an API change: one should use connection.begin_nested() instead of session.begin_nested(), according to @zzzeek:

    The "legacy" pattern that you have above which uses "session.begin_nested()" to create the savepoint, this is not supported for the "future" style engine which asyncio uses. The new version uses the connection itself to recreate the savepoint inside the event.