I have the following statement in one of the methods under unit test.
# Fetch every Employee whose dept equals the new employee's dept.
db_employees = (
    self.db._session.query(Employee)
    .filter(Employee.dept == new_employee.dept)
    .all()
)
I want db_employees to receive a mocked list of employees. I tried to achieve this using:
m = MagicMock()
# This sets the value returned by m().filter().all()() -- i.e. by *calling*
# the object that all() returns -- so all() itself still yields a mock.
m.return_value.filter().all().return_value = employees
where employees is a list of Employee objects. But this did not work: when I print the value of any attribute, I get a mock object instead of the real value. This is how the code looks:
class Database(object):
    """Hold an engine-backed session for the employee database.

    The helpers used here (create_engine, event, sessionmaker, Base) are
    presumably SQLAlchemy's -- their imports are not visible in this snippet.
    """

    def __init__(self, user=None, passwd=None, db="sqlite:////tmp/emp.db"):
        """Create an engine for ``db`` and open a session bound to it.

        :param user: accepted but not used by the code shown here.
        :param passwd: accepted but not used by the code shown here.
        :param db: connection URL handed to ``create_engine``.
        :raises ValueError: if ``create_engine`` raises any exception.
        """
        def on_connect(conn, record):
            # Issued on every new connection once registered below.
            conn.execute('pragma foreign_keys=ON')

        try:
            engine = create_engine(db)
        except Exception:
            raise ValueError("Database '%s' does not exist." % db)

        # The pragma listener is only registered for sqlite URLs.
        if 'sqlite://' in db:
            event.listen(engine, 'connect', on_connect)

        Base.metadata.bind = engine

        DBSession = sessionmaker(bind=engine)
        self._session = DBSession()
class TestEmployee(MyEmployee):
    """Run update_employees with Session.add and Session.query patched out."""

    def setUp(self):
        """Build the Database handle and call its session's _autoflush()."""
        self.db = emp.database.Database(db=options.connection)
        self.db._session._autoflush()

    @mock.patch.object(session.Session, 'add')
    @mock.patch.object(session.Session, 'query')
    def test_update(self, mock_query, mock_add):
        """Call update_employees with one employee record and mocked DB calls."""
        # Decorators apply bottom-up, so the 'query' patch arrives first.
        employees = [{'id': 1,
                      'name': 'Pradeep',
                      'department': 'IT',
                      'manager': 'John'}]
        mock_add.side_effect = self.add_side_effect
        mock_query.return_value = self.query_results()
        self.update_employees(employees)

    def add_side_effect(self, instance, _warn=True):
        """Side effect installed on the patched Session.add (body elided)."""
        # Code to mock add
        # Values will be stored in a dict which will be used to
        # check with expected value.

    def query_results(self):
        """Return the mock that test_update installs as query()'s return value.

        ``self.count``, ``employee`` and ``department`` are not defined in this
        snippet -- presumably they come from MyEmployee or the test module.
        """
        m = MagicMock()
        # NOTE(review): test_update assigns this mock to mock_query.return_value,
        # so the code under test receives ``m`` itself from query(...). The
        # chains below configure m(...).filter(...).all(), one call deeper than
        # query(...).filter(...).all() -- confirm this is intended.
        if self.count == 0:
            m.return_value.filter.return_value.all.return_value = [employee]
        else:
            m.return_value.filter.return_value.all.return_value = [department]
        return m
I use query_results because the method under test calls query twice: first on the employee table and then on the department table.
How do I mock this chained function call?
You should patch the `query()` method of the `Database` object's `_session` attribute and configure it to give you the right answer. You can do this in many ways, but IMHO the cleanest is to patch the static reference `DBSession.query`. I don't know which module you imported `DBSession` from, so I'll patch the local reference.
The other aspect is the mock configuration: we will set the return value of `query`, which in your case becomes the object that has the `filter()` method.
class TestEmployee(MyEmployee):
    """Run update_employees with DBSession.add and DBSession.query patched."""

    def setUp(self):
        """Build the Database handle and reset the log filled by add()."""
        self.db = emp.database.Database(db=options.connection)
        self.db._session._autoflush()
        self.log_add = {}

    # mock.patch (not mock.patch.object) takes a dotted-string target; prefixing
    # __name__ patches the DBSession reference living in this test module.
    @mock.patch(__name__ + '.DBSession.add')
    @mock.patch(__name__ + '.DBSession.query')
    def test_update(self, mock_query, mock_add):
        """Call update_employees with one employee record and mocked DB calls."""
        # Decorators apply bottom-up, so the 'query' patch arrives first.
        employees = [{'id': 1,
                      'name': 'Pradeep',
                      'department': 'IT',
                      'manager': 'John'}]
        mock_add.side_effect = self.add_side_effect
        mock_query.return_value = self.query_results()
        self.update_employees(employees)
        # ... your test here

    def add_side_effect(self, instance, _warn=True):
        """Side effect for the patched add(): record data in self.log_add."""
        # ... storing data
        # Placeholder key/value: substitute whatever the assertions need.
        self.log_add[...] = [...]

    def query_results(self):
        """Return the mock that test_update installs as query()'s return value.

        The mock is what query(...) returns, so configuring
        ``m.filter.return_value.all.return_value`` sets the result of
        query(...).filter(...).all().
        """
        m = MagicMock()
        # The strings are presumably placeholders for the real row lists.
        value = "[department]"
        # NOTE(review): self.count is not set in setUp above -- confirm that
        # MyEmployee (or the test itself) defines it.
        if not self.count:
            value = "[employee]"
        m.filter.return_value.all.return_value = value
        return m