I am currently working on extending a third-party code base. Unfortunately, this code base tightly couples its `get_args` to every other function. `get_args` is basically just a getter for a global object `_ARGS`. Now, I'd like to modify the args for a single function call without modifying the global object itself.
To this end, I used `unittest.mock.patch` to patch the `get_args` function. This successfully patches it within my target function `f`. However, the patch does not carry over to functions called by `f` that live in other modules. The reason, of course, is that I only patch `get_args` in the module that defines `f`. Every other module that did `from ... import get_args` keeps its own reference to the original function.
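The behaviour described above can be reproduced without any files. Below is a minimal sketch in which `types.ModuleType` objects serve as hypothetical in-memory stand-ins for the `mod0`, `mod1` and `mod2` modules of the example. Patching the name in one module leaves the other bindings of the same function untouched:

```python
import types
import unittest.mock


def get_args():
    return "real"


# Hypothetical in-memory stand-ins for the example modules mod0, mod1 and mod2.
mod0 = types.ModuleType("mod0")
mod1 = types.ModuleType("mod1")
mod2 = types.ModuleType("mod2")

mod0.get_args = get_args
# `from mod0 import get_args` copies the reference into the importing
# module's own namespace, so each module gets an independent name binding.
mod1.get_args = mod0.get_args
mod2.get_args = mod0.get_args

# Patch the name only in mod1, as `patch('mod1.get_args', ...)` does.
with unittest.mock.patch.object(mod1, "get_args", new=lambda: "mocked"):
    results = (mod1.get_args(), mod2.get_args(), mod0.get_args())

print(results)  # ('mocked', 'real', 'real')
```

On leaving the `with` block, `mod1.get_args` is restored to the original function.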
Is it possible to mock every subsequent call to `get_args` within my `with` block?
My current approach might not be the best way to tackle this problem, so I'm also open to alternative solutions.
Minimal reproducible example:

My code, `main.py`:
```python
from argparse import Namespace
import traceback
import unittest.mock

from mod0 import get_args
from mod1 import f1


class _MockCnt:
    def __init__(self):
        self._args = Namespace(**{
            **get_args().__dict__,
            'a': 'a',
        })

    def new_get_args(self):
        print("new_get_args was called")
        traceback.print_stack()
        return self._args


def main():
    with unittest.mock.patch('mod1.get_args', new=_MockCnt().new_get_args):
        f1()


main()
```
Module `mod0`:

```python
from argparse import Namespace

_ARGS = None


def get_args():
    global _ARGS
    if _ARGS is None:
        _ARGS = Namespace(a=1, b=2)
    return _ARGS
```
Module `mod1`:

```python
from mod0 import get_args
from mod2 import f2


def f1():
    args = get_args()
    args.c = 3
    print(f"[f1] Args: {args} (id {id(args)})")
    f2()
```
Module `mod2`:

```python
from mod0 import get_args


def f2():
    args = get_args()
    print(f"[f2] Args: {args} (id {id(args)})")
```
Result:

```
new_get_args was called
  File "/tmp/main.py", line 26, in <module>
    main()
  File "/tmp/main.py", line 24, in main
    f1()
  File "/tmp/mod1.py", line 6, in f1
    args = get_args()
  File "/tmp/main.py", line 19, in new_get_args
    traceback.print_stack()
[f1] Args: Namespace(a='a', b=2, c=3) (id 281472856422576)
[f2] Args: Namespace(a=1, b=2) (id 281472856399392)
```
What I need (leaving out the stack traces):

```
new_get_args was called
[f1] Args: Namespace(a='a', b=2, c=3) (id 281472856422576)
new_get_args was called
[f2] Args: Namespace(a='a', b=2, c=3) (id 281472856422576)
```
I finally found a simple solution for automatically and temporarily mocking out the same function in multiple modules. `sys.modules` is a dictionary that maps the names of all loaded modules to the module objects. Among these, I can find all modules that have a member `get_args` (or any of its parents) and then patch that member in each of them. Below are the changes to the code above:
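```python
import contextlib
import sys
import unittest.mock
from typing import Any, Iterator, List, Optional

_MOD_SEP = "."  # separator in dotted patch targets such as "mod1.get_args"


def main():
    with mock_all_calls_to_fn(get_args, new=_MockCnt().new_get_args):
        f1()


@contextlib.contextmanager
def mock_all_calls_to_fn(fn, *args, **kwargs) -> Iterator[None]:
    """Patch `fn` under every name it is bound to in the loaded modules."""
    with contextlib.ExitStack() as ctx_stack:
        for fn_name in find_all_names_of_fn(fn):
            patch = unittest.mock.patch(fn_name, *args, **kwargs)
            ctx_stack.enter_context(patch)
        yield


def find_all_names_of_fn(fn, root_module_name: Optional[str] = None) -> List[str]:
    """Return the 'module.attribute' targets whose value is the object `fn`."""
    fn_name = fn.__name__
    names = []
    # Iterate over a snapshot, since sys.modules may change while we scan it.
    for name, mod in list(sys.modules.items()):
        if root_module_name and not name.startswith(root_module_name + _MOD_SEP):
            continue
        if module_contains_obj(mod, fn_name, fn):
            names.append(_MOD_SEP.join((name, fn_name)))
    return names


def module_contains_obj(module, obj_name: str, obj: Any) -> bool:
    # Identity check: only match names bound to this exact object.
    return getattr(module, obj_name, None) is obj
```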
Running this with `python main.py` gives the expected result:
```
new_get_args was called
  File "/tmp/main.py", line 60, in <module>
    main()
  File "/tmp/main.py", line 58, in main
    f1()
  File "/tmp/mod1.py", line 6, in f1
    args = get_args()
  File "/tmp/main.py", line 52, in new_get_args
    traceback.print_stack()
[f1] Args: Namespace(a='a', b=2, c=3) (id 281473015935856)
new_get_args was called
  File "/tmp/main.py", line 60, in <module>
    main()
  File "/tmp/main.py", line 58, in main
    f1()
  File "/tmp/mod1.py", line 9, in f1
    f2()
  File "/tmp/mod2.py", line 5, in f2
    args = get_args()
  File "/tmp/main.py", line 52, in new_get_args
    traceback.print_stack()
[f2] Args: Namespace(a='a', b=2, c=3) (id 281473015935856)
```
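The lookup in `find_all_names_of_fn` relies on `sys.modules` being a dictionary from module names to module objects, combined with an identity test on one attribute per module. The following standalone illustration shows just that scan. It uses the standard library's `json.dumps` purely as an arbitrary example function. The exact list printed depends on what else the interpreter has loaded, so only `json` itself is guaranteed to appear:

```python
import sys
from json import dumps  # also binds the name `dumps` in this module

# sys.modules maps module names to module objects (a dict, not a list).
# Same identity test as module_contains_obj: which loaded modules expose
# an attribute named `dumps` that is this exact function object?
holders = sorted(
    name
    for name, mod in list(sys.modules.items())
    if getattr(mod, "dumps", None) is dumps
)
print(holders)
```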
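This also works if `get_args` is used as follows, because `mod0` itself is among the loaded modules, so `mod0.get_args` is patched as well:

```python
import mod0


def f2():
    mod0.get_args()
```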
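`sys.modules` only lists modules that have already been imported. Usually, all modules required for running `f1` in the above examples are loaded long before `f1` is called. However, the approach only patches names that exist when the `with` block is entered. It can therefore miss local and dynamic imports such as the following, if `mod0` is imported for the first time only after the patch has been applied:

```python
def f2():
    from mod0 import get_args
    get_args()
```

and

```python
import importlib


def f2():
    mod0 = importlib.import_module("mod0")
    mod0.get_args()
```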
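Since alternatives are welcome, here is a minimal sketch of a different approach. It assumes that callers only reach the arguments through `get_args` and that `get_args` simply returns the module-level `_ARGS`, as in `mod0` above. Instead of hunting down every reference to `get_args`, `unittest.mock.patch.object` temporarily rebinds `mod0._ARGS` to a modified copy. `get_args` always looks `_ARGS` up in `mod0`'s own namespace, so every caller sees the copy, no matter how or when it imported `get_args`. The original `Namespace` object is never modified. The `exec` into a `types.ModuleType` is only a stand-in for `mod0.py`, so the snippet runs on its own:

```python
import textwrap
import types
import unittest.mock
from argparse import Namespace

# Hypothetical in-memory copy of mod0.py, so this sketch needs no extra files.
mod0 = types.ModuleType("mod0")
exec(textwrap.dedent("""
    from argparse import Namespace

    _ARGS = None

    def get_args():
        global _ARGS
        if _ARGS is None:
            _ARGS = Namespace(a=1, b=2)
        return _ARGS
"""), mod0.__dict__)

# Equivalent of `from mod0 import get_args` in mod1/mod2: a separate name
# bound to the same function, whose globals are still mod0's namespace.
get_args = mod0.get_args


def f1():
    args = get_args()
    args.c = 3
    return args


original = get_args()  # initializes and returns the real global _ARGS
override = Namespace(**{**vars(original), "a": "a"})

# Rebind mod0._ARGS for the duration of the block; patch restores it on exit.
with unittest.mock.patch.object(mod0, "_ARGS", override):
    inside = f1()

after = get_args()
print(inside)  # Namespace(a='a', b=2, c=3)
print(after)   # Namespace(a=1, b=2)
```

In the real project this would amount to wrapping the call to `f1` in `unittest.mock.patch("mod0._ARGS", new=...)`. A module that held `_ARGS` directly (e.g. via `from mod0 import _ARGS`) would not see the override.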