python, pyspark, pytest, python-unittest, python-unittest.mock

Python test: mock/patch to change an internal function's arguments while still running the function


I'm looking to mock (or patch) a function so that I can replace the arguments it receives. An example of what I want to do:

# my_module.my_submodule

from some_library import some_module as x

def do_thing(a, b=None):
    return a + x.random_number(b)
    

# my_module.my_other_submodule

from my_module.my_submodule import do_thing

def do_more_complex_thing(a):

    # Note: we do not pass 'b' here
    return do_thing(a)


# test.py

from my_module.my_other_submodule import do_more_complex_thing

def test_do_more_complex_thing():

    # I want to test this function, but I need to make sure that 
    # when x.random_number(b) is called, it receives a particular argument.
    # Note: I do not want to mock the return from x.random_number, only the 
    # arguments it receives.
    assert do_more_complex_thing(1) == 50

More concretely, I have some code that calls rand(seed) from pyspark.sql.functions at some point (deeply nested beyond the function I am testing), and I need to override the value of the seed to ensure my tests are deterministic. Changing the function signatures to pass the seed through is not an option, so we need to mock this.

I've taken a look through the unittest.mock documentation and other answers on here, and most of them seem to focus on replacing the mocked function with either a different function, or mocking out the return values. In my case, I still want the function to run like it should, but I just need to change the arguments it receives.

I have tried patching the target function with a wrapper that calls it with manually set arguments, but this ends up in recursion errors, so I am clearly doing something wrong. Closer to my real-world example:

# my_module.spark

import pyspark.sql.functions as f

def do_spark_thing():
    ...
    a = 1  # not settable from method signature
    f.rand(a)

# test.py

from unittest.mock import patch

import pyspark.sql.functions as f
from my_module.spark import do_spark_thing


def test_do_spark_thing():

    def _set_seed(*args, **kwargs):
        return f.rand(1)

    with patch('my_module.spark.f.rand', _set_seed):
        # Gives recursion error
        do_spark_thing()

I am also using pytest as my core test framework, if there are any better options there vs unittest. What is the recommended way forward for this case?


Solution

  • You can't change a variable inside a running function. (Ok, it could be done, but it would involve enabling code-tracing and a lot of complications).

    However, it is straightforward to change the global variables of the module where the function lives - just make an attribute assignment on that module.

    In this case that means that instead of changing b, you replace rand with a wrapper function that calls the original rand with your desired seed.

    This is usually referred to as "monkey patching", and Python's unittest.mock module contains the patch callable which provides some facilities for that.

    So, inside your test, you write something like:

    ...
    from unittest.mock import patch
    import my_module.spark
    fixed_b = ...
    # grab a reference to the original *before* patching, so the wrapper
    # calls the real function rather than the patched attribute (which is
    # what caused your recursion error)
    original_rand = my_module.spark.f.rand
    new_rand = lambda *args, **kwargs: original_rand(fixed_b)
    with patch("my_module.spark.f.rand", new_rand):
        # testing whatever you need with the fixed seed goes here
        ...
    # the original "rand" is restored at the end of the `with` block
    ...
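
    Putting that together with the my_module.spark example from the question - where the functions module is imported as f, so the attribute to patch is "my_module.spark.f.rand" - a full test might look like the sketch below. The FIXED_SEED value, the test name and the seen_seeds bookkeeping are illustrative only, and it assumes pyspark is importable in your test environment:

    # test.py

    from unittest.mock import patch

    import pyspark.sql.functions as f
    from my_module.spark import do_spark_thing

    FIXED_SEED = 42  # any value, as long as it is fixed

    def test_do_spark_thing_is_deterministic():
        # capture the real rand *before* patching, so the wrapper
        # does not end up calling itself (the recursion you saw)
        original_rand = f.rand
        seen_seeds = []

        def _rand_with_fixed_seed(*args, **kwargs):
            # ignore whatever seed (if any) the production code passed in
            seen_seeds.append(FIXED_SEED)
            return original_rand(FIXED_SEED)

        with patch("my_module.spark.f.rand", _rand_with_fixed_seed):
            do_spark_thing()

        # the wrapper ran exactly once and forwarded the fixed seed
        assert seen_seeds == [FIXED_SEED]

    Since you are already on pytest, its built-in monkeypatch fixture does the same attribute swap - monkeypatch.setattr(f, "rand", _rand_with_fixed_seed) - and undoes it automatically when the test finishes, so there is no with block to manage.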