
How to mock a function in multiple modules


I am currently extending a third-party code base. Unfortunately, this code base tightly couples nearly every function to its get_args function, which is essentially a getter for the global object _ARGS. I'd like to modify the args for a single function call without modifying the global object itself.
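
For reference, such a modified copy can be built from an argparse.Namespace without touching the original. This is the same idiom _MockCnt uses in the example below (vars(ns) returns the same dict as ns.__dict__). A minimal sketch, where global_args is a hypothetical stand-in for _ARGS:

```python
from argparse import Namespace

global_args = Namespace(a=1, b=2)  # hypothetical stand-in for the global _ARGS

# Merge the old attributes with the overrides into a brand-new Namespace.
local_args = Namespace(**{**vars(global_args), "a": "a"})
local_args.c = 3  # mutates only the copy

print(global_args)  # Namespace(a=1, b=2)
print(local_args)   # Namespace(a='a', b=2, c=3)
```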

To this end, I used unittest.mock.patch to patch the get_args function. While this works inside my target function f, the patch does not carry over to functions that f calls from other modules. The reason, of course, is that I only patch get_args in the module that contains f.
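
A minimal, self-contained sketch of why this happens: each `from mod0 import get_args` creates a separate name in the importing module, and unittest.mock.patch replaces only the one attribute it is pointed at. The modules demo_a0, demo_a1 and demo_a2 are hypothetical stand-ins for mod0, mod1 and mod2, created in memory with types.ModuleType so the example runs without extra files.

```python
import sys
import types
import unittest.mock


def make_module(name: str, source: str) -> types.ModuleType:
    """Create an in-memory module, register it in sys.modules and run its source."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


# Hypothetical stand-ins for mod0, mod1 and mod2 from the question.
demo_a0 = make_module("demo_a0", "def get_args():\n    return 'original'\n")
demo_a1 = make_module("demo_a1", "from demo_a0 import get_args\n")
demo_a2 = make_module("demo_a2", "from demo_a0 import get_args\n")

# Each "from demo_a0 import get_args" bound its own name in its own module.
with unittest.mock.patch("demo_a1.get_args", new=lambda: "patched"):
    seen_by_a1 = demo_a1.get_args()
    seen_by_a2 = demo_a2.get_args()  # demo_a2's binding was never touched

print(seen_by_a1, seen_by_a2)  # patched original
```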

Is it possible to mock every subsequent call to get_args within my with block?

My current approach might not be the best way to tackle this, so I'm also open to alternative solutions.

Minimal reproducible example:

My code main.py:

from argparse import Namespace
import traceback
import unittest.mock

from mod0 import get_args
from mod1 import f1


class _MockCnt:

    def __init__(self):
        self._args = Namespace(**{
            **get_args().__dict__,
            'a': 'a',
        })

    def new_get_args(self):
        print("new_get_args was called")
        traceback.print_stack()
        return self._args


def main():
    with unittest.mock.patch('mod1.get_args', new=_MockCnt().new_get_args):
        f1()

main()

Module mod0:

from argparse import Namespace

_ARGS = None


def get_args():
    global _ARGS
    if _ARGS is None:
        _ARGS = Namespace(a=1, b=2)
    return _ARGS

Module mod1:

from mod0 import get_args
from mod2 import f2


def f1():
    args = get_args()
    args.c = 3
    print(f"[f1] Args: {args} (id {id(args)})")
    f2()

Module mod2:

from mod0 import get_args


def f2():
    args = get_args()
    print(f"[f2] Args: {args} (id {id(args)})")

Result:

new_get_args was called
  File "/tmp/main.py", line 26, in <module>
    main()
  File "/tmp/main.py", line 24, in main
    f1()
  File "/tmp/mod1.py", line 6, in f1
    args = get_args()
  File "/tmp/main.py", line 19, in new_get_args
    traceback.print_stack()
[f1] Args: Namespace(a='a', b=2, c=3) (id 281472856422576)
[f2] Args: Namespace(a=1, b=2) (id 281472856399392)

What I need (stack traces omitted):

new_get_args was called
[f1] Args: Namespace(a='a', b=2, c=3) (id 281472856422576)
new_get_args was called
[f2] Args: Namespace(a='a', b=2, c=3) (id 281472856422576)

Solution


    I finally found a simple way to temporarily mock the same function in multiple modules automatically. sys.modules is a dictionary that maps the names of all loaded modules to the module objects. By iterating over it, I can find every module that has an attribute get_args referring to the original function and patch each of them. Below are the changes to the code above:

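    import contextlib
    import sys
    import unittest.mock
    from typing import Any, Iterator, List, Optional

    _MOD_SEP = "."  # separator between the parts of a dotted module path


    # Replaces main() from above; main() must still be called at the very end of main.py.
    def main():
        with mock_all_calls_to_fn(get_args, new=_MockCnt().new_get_args):
            f1()


    @contextlib.contextmanager
    def mock_all_calls_to_fn(fn, *args, **kwargs) -> Iterator[None]:
        """Patch every module-level name that refers to fn until the block exits."""
        with contextlib.ExitStack() as ctx_stack:
            for fn_name in find_all_names_of_fn(fn):
                patch = unittest.mock.patch(fn_name, *args, **kwargs)
                ctx_stack.enter_context(patch)
            yield


    def find_all_names_of_fn(fn, root_module_name: Optional[str] = None) -> List[str]:
        """Return "module.attr" targets for all loaded modules holding fn under its own name."""
        fn_name = fn.__name__
        names = []
        # Iterate over a snapshot, since sys.modules may change while modules are inspected.
        for name, mod in list(sys.modules.items()):
            # Optionally restrict the search to submodules of root_module_name.
            if root_module_name and not name.startswith(root_module_name + _MOD_SEP):
                continue
            if module_contains_obj(mod, fn_name, fn):
                full_name = _MOD_SEP.join((name, fn_name))
                names.append(full_name)
        return names


    def module_contains_obj(module, obj_name: str, obj: Any) -> bool:
        # Compare by identity: only the very same function object counts.
        return getattr(module, obj_name, None) is obj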
    

    Result

    Running this with python main.py gives the expected result:

    new_get_args was called
      File "/tmp/main.py", line 60, in <module>
        main()
      File "/tmp/main.py", line 58, in main
        f1()
      File "/tmp/mod1.py", line 6, in f1
        args = get_args()
      File "/tmp/main.py", line 52, in new_get_args
        traceback.print_stack()
    [f1] Args: Namespace(a='a', b=2, c=3) (id 281473015935856)
    new_get_args was called
      File "/tmp/main.py", line 60, in <module>
        main()
      File "/tmp/main.py", line 58, in main
        f1()
      File "/tmp/mod1.py", line 9, in f1
        f2()
      File "/tmp/mod2.py", line 5, in f2
        args = get_args()
      File "/tmp/main.py", line 52, in new_get_args
        traceback.print_stack()
    [f2] Args: Namespace(a='a', b=2, c=3) (id 281473015935856)
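
To check the helpers without the files from the question, here is a self-contained sketch. It re-declares find_all_names_of_fn and mock_all_calls_to_fn in condensed form (root_module_name omitted) and runs them against in-memory modules; demo_b0, demo_b1 and demo_b2 are hypothetical stand-ins for mod0, mod1 and mod2.

```python
import contextlib
import sys
import types
import unittest.mock
from typing import Iterator, List


def find_all_names_of_fn(fn) -> List[str]:
    """Return "module.attr" targets of all loaded modules holding fn under its name."""
    fn_name = fn.__name__
    return [
        f"{name}.{fn_name}"
        for name, mod in list(sys.modules.items())  # snapshot of loaded modules
        if getattr(mod, fn_name, None) is fn
    ]


@contextlib.contextmanager
def mock_all_calls_to_fn(fn, *args, **kwargs) -> Iterator[None]:
    """Patch every target found above; ExitStack undoes all patches on exit."""
    with contextlib.ExitStack() as stack:
        for target in find_all_names_of_fn(fn):
            stack.enter_context(unittest.mock.patch(target, *args, **kwargs))
        yield


def make_module(name: str, source: str) -> types.ModuleType:
    """Create an in-memory module, register it in sys.modules and run its source."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


# Hypothetical stand-ins for mod0, mod2 and mod1 (mod2 first, since mod1 imports it).
demo_b0 = make_module("demo_b0", "def get_args():\n    return 'original'\n")
demo_b2 = make_module(
    "demo_b2",
    "from demo_b0 import get_args\n"
    "def f2():\n"
    "    return get_args()\n",
)
demo_b1 = make_module(
    "demo_b1",
    "from demo_b0 import get_args\n"
    "from demo_b2 import f2\n"
    "def f1():\n"
    "    return get_args(), f2()\n",
)

with mock_all_calls_to_fn(demo_b0.get_args, new=lambda: "patched"):
    inside = demo_b1.f1()

after = demo_b1.f1()
print(inside, after)  # ('patched', 'patched') ('original', 'original')
```

Both f1 and the nested f2 see the mock inside the block, and all three bindings are restored afterwards.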
    

    It also works if get_args is accessed through its module, since mod0 itself is one of the patched modules:

    import mod0
    
    
    def f2():
        mod0.get_args()
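
The module-attribute form works because `mod0.get_args` is looked up on the module object each time f2 runs, so replacing the attribute on the defining module is enough. A minimal sketch of that call-time lookup, using only a plain unittest.mock.patch; demo_c0 is a hypothetical in-memory stand-in for mod0.

```python
import sys
import types
import unittest.mock

# Hypothetical stand-in for mod0, created in memory.
demo_c0 = types.ModuleType("demo_c0")
sys.modules["demo_c0"] = demo_c0
exec("def get_args():\n    return 'original'\n", demo_c0.__dict__)


def f2():
    # The attribute is resolved on demo_c0 at call time, not at import time.
    return demo_c0.get_args()


with unittest.mock.patch("demo_c0.get_args", new=lambda: "patched"):
    inside = f2()

print(inside, f2())  # patched original
```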
    

    Limitations

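    sys.modules only contains modules that have already been imported, so the set of patched names is fixed when the with block is entered. Usually, all modules required for running f1 in the examples above are loaded long before f1 is called. Local and dynamic imports, such as

    def f2():
        from mod0 import get_args
        get_args()


    and

    def f2():
        import importlib
        mod0 = importlib.import_module("mod0")
        mod0.get_args()

    are only covered indirectly: they look up get_args on mod0 when f2 runs, so they receive the mock because mod0 itself is one of the patched modules. If mod0 is excluded (for example via root_module_name), they get the original function. References that are not module attributes named get_args, such as an alias created with from mod0 import get_args as ga, are not found at all.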
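
A different approach, not part of the answer above and sketched here only as an alternative: because get_args reads the module-global _ARGS every time it is called, temporarily replacing mod0._ARGS with a modified copy via unittest.mock.patch.object affects every caller at once, however it imported get_args. Only the global binding is swapped; the original Namespace object is never mutated and is restored on exit. This assumes the third-party get_args behaves like the mod0 shown in the question; demo_d0 is a hypothetical in-memory stand-in.

```python
import sys
import types
import unittest.mock
from argparse import Namespace

# Hypothetical in-memory stand-in for mod0 from the question.
demo_d0 = types.ModuleType("demo_d0")
sys.modules["demo_d0"] = demo_d0
exec(
    "from argparse import Namespace\n"
    "_ARGS = None\n"
    "def get_args():\n"
    "    global _ARGS\n"
    "    if _ARGS is None:\n"
    "        _ARGS = Namespace(a=1, b=2)\n"
    "    return _ARGS\n",
    demo_d0.__dict__,
)

# A caller that bound get_args at import time, like mod1 and mod2 do.
get_args = demo_d0.get_args

original = get_args()
modified = Namespace(**{**vars(original), "a": "a"})

# Swap the global binding, not the object; patch.object restores it on exit.
with unittest.mock.patch.object(demo_d0, "_ARGS", modified):
    inside = get_args()
    inside.c = 3  # mutates only the copy

print(inside)      # Namespace(a='a', b=2, c=3)
print(get_args())  # Namespace(a=1, b=2)
```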