
How to create a class that runs business logic upon a query?


I'd like to create a class/object that contains business logic and that I can use for querying.

Constraints:

How do I do that? Is that even possible?

Use Case

My database table has two columns: value_a and show_value_a. show_value_a specifies whether the value should be shown in the UI. Currently, every process that queries value_a has to check whether show_value_a is True; if it is not, value_a must be masked (i.e. set to None) before it is returned.

Masking the value is easy to forget. Also, each process has its own specific query (with its own JOINs), so it's impractical to apply the masking through some shared query pattern.
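For reference, the rule that each process currently has to re-apply after querying can be sketched in plain Python. `mask_value_a` is a hypothetical helper name, and the rows simply mimic the sample data below.

```python
from typing import List, Optional, Tuple


def mask_value_a(value_a: str, show_value_a: bool) -> Optional[str]:
    """Hypothetical helper: return value_a only when show_value_a is True."""
    return value_a if show_value_a else None


# Rows as one process's own query might return them (sample data from the question).
rows: List[Tuple[str, bool]] = [("A", True), ("B", False), ("C", True)]

# Every process must remember to apply the mask itself; forgetting it leaks "B".
masked = [mask_value_a(value_a, show) for value_a, show in rows]
print(masked)  # ['A', None, 'C']
```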

Example

Table definition:

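from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class MyTable(Base):
    __tablename__ = "mytable"

    id = Column(Integer, primary_key=True)  # the ORM requires a primary key
    valueA = Column("value_a", String(60), nullable=False)
    showValueA = Column("show_value_a", Boolean, nullable=False)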

Data:

value_a | show_value_a
--------+-------------
"A"     | True
"B"     | False
"C"     | True

Query I'd like to do:

values = session.query(MyTable.valueA).all()
# desired result: "A", None, "C" (one row each; "B" is masked)

Querying the column should intrinsically check whether show_value_a is True. If it is, the value is returned; if not, None is returned.
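At the SQL level, this behavior amounts to a CASE expression evaluated inside the query itself, which is the kind of expression the sa.case(...) call in the answer generates. A minimal sketch using Python's built-in sqlite3 module; the id primary key is an assumption added so the rows can be ordered.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE mytable ("
    " id INTEGER PRIMARY KEY,"
    " value_a VARCHAR(60) NOT NULL,"
    " show_value_a BOOLEAN NOT NULL)"
)
# sqlite3 stores Python booleans as the integers 1 and 0.
conn.executemany(
    "INSERT INTO mytable (value_a, show_value_a) VALUES (?, ?)",
    [("A", True), ("B", False), ("C", True)],
)

# The masking rule lives in the query: NULL unless show_value_a is true.
rows = conn.execute(
    "SELECT CASE WHEN show_value_a THEN value_a ELSE NULL END AS value_a "
    "FROM mytable ORDER BY id"
).fetchall()
values = [row[0] for row in rows]
print(values)  # ['A', None, 'C']
```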


Solution

  • You can use the do_orm_execute session event to intercept queries and modify them before they are executed. This sample event handler:

    1. Checks the session's info dictionary to determine whether the query relates to an entity of interest
    2. Creates a modified query that checks whether valueA can be shown
    3. Replaces the original query with the modified query
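    # Assumes `sa` (import sqlalchemy as sa), `Session` (a sessionmaker or the
    # Session class) and the mapped class `MyTable` are already defined.
    @sa.event.listens_for(Session, 'do_orm_execute')
    def _do_orm_execute(orm_execute_state):
        # Only SELECT statements are rewritten.
        if orm_execute_state.is_select:
            statement = orm_execute_state.statement
            col_descriptions = statement.column_descriptions
            # 1. Is the first entity in the query one whose values must be checked?
            if (
                col_descriptions[0]['entity']
                in orm_execute_state.session.info['check_entities']
            ):
                # 2. Return value_a only when showValueA is true, otherwise NULL.
                expr = sa.case((MyTable.showValueA, MyTable.valueA), else_=None).label(
                    'value_a'
                )
                columns = [
                    c if c.name != 'value_a' else expr for c in statement.inner_columns
                ]
                # 3. Replace the original statement; rows are loaded back into MyTable.
                new_statement = sa.select(MyTable).from_statement(sa.select(*columns))
                orm_execute_state.statement = new_statement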
    

    Note that this only works for 2.0-style queries (or for 1.4 with the future=True flag set on engines and sessions). The code assumes a simple select(MyTable) query: you would need to carry over the WHERE criteria, ORDER BY clauses, etc. from the original query. Joins might also require additional work.
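One way to keep the original WHERE and ORDER BY clauses is Select.with_only_columns(), which replaces only the columns clause of a statement. The following is a Core-level sketch under that assumption, outside the event hook; `mask_value_a_columns` is a hypothetical helper, the id primary key is assumed, and wiring this into _do_orm_execute for ORM entities is not shown and may need more work.

```python
import sqlalchemy as sa

metadata = sa.MetaData()
mytable = sa.Table(
    "mytable",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),  # assumed primary key
    sa.Column("value_a", sa.String(60), nullable=False),
    sa.Column("show_value_a", sa.Boolean, nullable=False),
)


def mask_value_a_columns(stmt):
    """Hypothetical helper: swap value_a for a masked CASE expression while
    keeping the statement's FROM, WHERE and ORDER BY clauses."""
    masked = sa.case(
        (mytable.c.show_value_a, mytable.c.value_a), else_=None
    ).label("value_a")
    columns = [masked if c.name == "value_a" else c for c in stmt.selected_columns]
    return stmt.with_only_columns(*columns)


engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(
        mytable.insert(),
        [
            {"value_a": "A", "show_value_a": True},
            {"value_a": "B", "show_value_a": False},
            {"value_a": "C", "show_value_a": True},
        ],
    )

# Stand-in for one process's query, with its own WHERE and ORDER BY criteria.
original = sa.select(mytable).where(mytable.c.id < 3).order_by(mytable.c.id)

with engine.connect() as conn:
    rows = [tuple(row) for row in conn.execute(mask_value_a_columns(original))]
print(rows)
```

Because only the columns clause changes, the filter on id and the ordering survive the rewrite.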

    Here's a runnable example:

    import sqlalchemy as sa
    from sqlalchemy import orm
    from sqlalchemy.orm import Mapped, mapped_column
    
    class Base(orm.DeclarativeBase):
        pass
    
    
    class MyTable(Base):
        __tablename__ = 't79426130'
    
        id: Mapped[int] = mapped_column(primary_key=True)
        valueA: Mapped[str] = mapped_column('value_a')
        showValueA: Mapped[bool] = mapped_column('show_value_a')
    
    
    engine = sa.create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)
    info = {'check_entities': {MyTable}}
    Session = orm.sessionmaker(engine, info=info)
    
    
    @sa.event.listens_for(Session, 'do_orm_execute')
    def _do_orm_execute(orm_execute_state):
        if orm_execute_state.is_select:
            statement = orm_execute_state.statement
            col_descriptions = statement.column_descriptions
            if (
                col_descriptions[0]['entity']
                in orm_execute_state.session.info['check_entities']
            ):
                expr = sa.case((MyTable.showValueA, MyTable.valueA), else_=None).label(
                    'value_a'
                )
                columns = [
                    c if c.name != 'value_a' else expr for c in statement.inner_columns
                ]
                new_statement = sa.select(MyTable).from_statement(sa.select(*columns))
                orm_execute_state.statement = new_statement
    
    
    with Session.begin() as s:
        # Use a distinct loop variable so it does not reuse the session's name `s`.
        mts = [
            MyTable(valueA=v, showValueA=show)
            for v, show in zip('ABC', [True, False, True])
        ]
        s.add_all(mts)
    
    with Session() as s:
        for mt in s.scalars(sa.select(MyTable)):
            print(mt.valueA, mt.showValueA)
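A different technique, not part of the answer above, is to put the rule on the model as a hybrid property: the same attribute masks in Python on instances and renders as a SQL CASE expression in queries, so it works inside any process's own JOINs. This sketch assumes the raw column is renamed to a hypothetical `rawValueA` attribute so that the obvious name, `valueA`, is the masked one; code that still reads `rawValueA` directly bypasses the mask.

```python
import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class MyTable(Base):
    __tablename__ = "mytable"

    id = sa.Column(sa.Integer, primary_key=True)
    # Hypothetical rename: the raw column lives under a less obvious attribute.
    rawValueA = sa.Column("value_a", sa.String(60), nullable=False)
    showValueA = sa.Column("show_value_a", sa.Boolean, nullable=False)

    @hybrid_property
    def valueA(self):
        # Instance level: plain Python masking.
        return self.rawValueA if self.showValueA else None

    @valueA.expression
    def valueA(cls):
        # Class level: the same rule rendered as a SQL CASE expression.
        return sa.case((cls.showValueA, cls.rawValueA), else_=None)


engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            MyTable(rawValueA=v, showValueA=show)
            for v, show in zip("ABC", [True, False, True])
        ]
    )
    session.commit()

    # 2.0-style and legacy Query-style column queries both get the masked value.
    values = (
        session.execute(sa.select(MyTable.valueA).order_by(MyTable.id))
        .scalars()
        .all()
    )
    legacy = [row[0] for row in session.query(MyTable.valueA).order_by(MyTable.id)]
    masked_b = session.get(MyTable, 2).valueA
```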