I want to write a unit test with pytest for this method:
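```python
from typing import Optional

from sqlmodel import select
# User and SessionDep are defined in the application's own modules


class UserService:
    @classmethod
    async def get_user_by_login(cls, session: SessionDep, login: str) -> Optional[User]:
        sql = select(User).where(User.login == login)
        result = await session.exec(sql)
        return result.first()
```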
I already have this test:
```python
import pytest
from sqlmodel import select

# User, UserService and LOGIN come from the application / test module


@pytest.mark.asyncio
async def test_get_user_by_login(mocker):
    # Arrange
    # session.exec() is awaited, but the result it returns is used synchronously
    mock_exec_result = mocker.MagicMock()
    mock_exec_result.first.return_value = None
    mock_session = mocker.AsyncMock()
    mock_session.exec.return_value = mock_exec_result
    sql_request = select(User).where(User.login == LOGIN)
    # Act
    await UserService.get_user_by_login(mock_session, LOGIN)
    # Assert
    mock_session.exec.assert_called_once_with(sql_request)
    mock_session.exec.assert_called_once()
    mock_exec_result.first.assert_called_once_with()
```
But I want to check the `sql_request` variable with `assert_called_once_with`, and it obviously fails because the two statements are not the same object instance. Here is the pytest output:
```
# Assert
> mock_session.exec.assert_called_once_with(sql_request)
E AssertionError: expected call not found.
E Expected: exec(<sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478A0EFC0>)
E Actual: exec(<sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478884B00>)
E
E pytest introspection follows:
E
E Args:
E assert (<sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478884B00>,) == (<sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478A0EFC0>,)
E
E At index 0 diff: <sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478884B00> != <sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478A0EFC0>
E
E Full diff:
E (
E - <sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478A0EFC0>,
E ? ^ ---
E + <sqlmodel.sql._expression_select_cls.SelectOfScalar object at 0x0000028478884B00>,
E ? ^^^^
E )
```
You are not supposed to compare instances of SQL expressions: two statements built separately are different objects and do not compare equal, even when they describe the same query. Compare the SQL query itself instead. The mock stores the argument that was passed to the `mock_session.exec` method, so you can extract it from the coroutine mock's recorded arguments and compare its compiled form with your expected query:
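A minimal, stdlib-only sketch of how that recording works, using a placeholder object instead of a real SQL statement (an assumption, so it runs without SQLAlchemy). `run_once` and `recorded_statement` are hypothetical names used only for illustration:

```python
import asyncio
from unittest.mock import AsyncMock, call


async def run_once(exec_mock: AsyncMock, statement: object) -> object:
    # Hypothetical stand-in for the code under test: awaits exec with one argument
    return await exec_mock(statement)


def recorded_statement(exec_mock: AsyncMock) -> object:
    """Hypothetical helper: return the positional argument exec_mock was awaited with."""
    exec_mock.assert_awaited_once()
    # await_args is a call object; await_args[0] and await_args.args are
    # the same tuple of positional arguments
    assert exec_mock.await_args[0] == exec_mock.await_args.args
    return exec_mock.await_args.args[0]


exec_mock = AsyncMock(return_value="fake result")
statement = object()  # placeholder for the SELECT statement
assert asyncio.run(run_once(exec_mock, statement)) == "fake result"
assert recorded_statement(exec_mock) is statement
assert exec_mock.await_args == call(statement)
```

So `await_args[0][0]` and `await_args.args[0]` refer to the same object; the `.args` attribute of call objects (Python 3.8+) is simply easier to read.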
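```python
from sqlalchemy.dialects import postgresql

sql_expression = mock_session.exec.await_args.args[0]  # statement passed to exec
actual = sql_expression.compile(dialect=postgresql.dialect())
expected = select(User).where(User.login == LOGIN).compile(dialect=postgresql.dialect())
assert (str(actual), actual.params) == (str(expected), expected.params)  # SQL text and bound values
```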
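Putting it together, here is a self-contained sketch of the same test that runs without a database or pytest plugins. It makes a few assumptions: a plain SQLAlchemy Core `users` table stands in for the `User` model, a module-level `get_user_by_login` function stands in for the `UserService` classmethod, `compiled_sql` is a hypothetical helper, and `asyncio.run` replaces `pytest.mark.asyncio`. It compares both the SQL text and the bound parameters, because the compiled string only contains a placeholder such as `%(login_1)s`, not the login value itself:

```python
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call

from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.dialects import postgresql

# Stand-in for the question's User model (assumption): a plain Core table,
# so the sketch runs without SQLModel or a database.
metadata = MetaData()
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("login", String),
)


async def get_user_by_login(session: Any, login: str) -> Any:
    # Same shape as UserService.get_user_by_login in the question
    sql = select(users).where(users.c.login == login)
    result = await session.exec(sql)
    return result.first()


def compiled_sql(statement: Any, dialect: Any = None) -> tuple:
    """Hypothetical helper: compile a statement and return (SQL text, bound params)."""
    compiled = statement.compile(dialect=dialect or postgresql.dialect())
    return str(compiled), compiled.params


async def check_query() -> None:
    # session.exec() is awaited, so the session is an AsyncMock;
    # the result object is used synchronously, so a MagicMock fits it.
    mock_result = MagicMock()
    mock_result.first.return_value = None
    mock_session = AsyncMock()
    mock_session.exec.return_value = mock_result

    user = await get_user_by_login(mock_session, "alice")

    expected = select(users).where(users.c.login == "alice")
    mock_session.exec.assert_awaited_once()
    # Comparing call objects fails: the two statements are distinct objects
    assert mock_session.exec.await_args != call(expected)
    # Comparing the compiled SQL text and bound values succeeds
    actual = mock_session.exec.await_args.args[0]
    assert compiled_sql(actual) == compiled_sql(expected)
    mock_result.first.assert_called_once_with()
    assert user is None


asyncio.run(check_query())
```

Inside a real pytest module, the body of `check_query` can go straight into the `async def test_...` function.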
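This test is coupled to the exact query: if the SQL expression in `get_user_by_login` changes later, the test will break and has to be updated with it. Compile both statements with the dialect of the database you actually use, because the generated SQL (quoting, bind-parameter placeholders) differs between dialects.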
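A short sketch of why the dialect matters, again using a stand-in `users` table (an assumption) instead of the `User` model. It also shows `compile_kwargs={"literal_binds": True}`, which renders the bound value into the SQL string so that a plain string comparison checks the login value too; this works for simple literal types such as strings:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.dialects import postgresql, sqlite

# Stand-in table (assumption) in place of the question's User model
metadata = MetaData()
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("login", String),
)
stmt = select(users).where(users.c.login == "alice")

# Same statement, different dialects: the bind placeholder style differs
pg_sql = str(stmt.compile(dialect=postgresql.dialect()))
sqlite_sql = str(stmt.compile(dialect=sqlite.dialect()))

# literal_binds renders the bound value inline, so comparing strings
# also compares the login value
pg_literal = str(
    stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})
)

print(pg_sql)
print(sqlite_sql)
print(pg_literal)
```

The PostgreSQL string contains `WHERE users.login = %(login_1)s`, while the SQLite one contains `WHERE users.login = ?`, so a query compiled for one dialect will not match a query compiled for the other.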