pythonpytest-mock

How can I "dry" up this pytest-mock?


I have some test methods that require a mock of a class. The below code works but I am repeating the same mock in several methods. Here is the code that is getting repeated. You can see it used below in the last 3 methods.

class PyWaves(object):
    def names(self):
        return test_class_self.mock_wave_names()

I tried moving it to the test class's __init__ but that does not work. What's the right way to avoid repeating this code in each method where I need the mock?

class TestExtractedXPSNameAdapter:

    # def __init__(self, mocker):
    #     class PyWaves(object):
    #         def names(self):
    #             return self.mock_wave_names()
    #     self.waves = PyWaves()


    def mock_wave_names(self): 
        return ['net_a', 'top.net_b', 'top.foo@bar@net1#foo@bar@inst@0_g', 'top.foo1@bar1@net2#foo1@bar1@inst1_d']
    
    def expected_index(self): 
        return {'net_a': 'net_a', 'top.net_b': 'top.net_b', 'foo@bar@inst@0_g': 'top.foo@bar@net1#foo@bar@inst@0_g', 'foo1@bar1@inst1_d': 'top.foo1@bar1@net2#foo1@bar1@inst1_d'}

    def test_gen_populated_name_guesses(self):
        adapter = ExtractedXPSNameAdapter()
        identifier = WaveformIdentifier(instance_name="my_inst", path="/i_macro/sub1/sub2", term_name="my_terminal")
        guesses = adapter.gen_name_guesses(identifier)
        expected_guesses = [
            "i_macro@sub1@sub2@my_inst@0_my_terminal",
            "i_macro@sub1@sub2@my_inst_my_terminal",
            ]
        assert guesses == expected_guesses

    def test_gen_empty_name_guesses(self):
        adapter = ExtractedXPSNameAdapter()
        identifier = WaveformIdentifier(instance_name="my_inst", path="/i_macro/sub1/sub2")
        guesses = adapter.gen_name_guesses(identifier)
        expected_guesses = []
        assert guesses == expected_guesses

    def test_build_index(self, mocker):
        test_class_self = self
        class PyWaves(object):
            def names(self):
                return test_class_self.mock_wave_names()
        waves = PyWaves()
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(waves)
        adapter.build_index()
        assert adapter.index == self.expected_index()

    def test_waveform_name(self, mocker):
        test_class_self = self
        class PyWaves(object):
            def names(self):
                return test_class_self.mock_wave_names()
        waves = PyWaves()
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(waves)
        identifier = WaveformIdentifier(instance_name='inst1', path='/foo1/bar1', term_name="d")
        name = adapter.waveform_name(identifier)
        assert name == 'top.foo1@bar1@net2#foo1@bar1@inst1_d'

    def test_info_from_identifier(self, mocker):
        test_class_self = self
        class PyWaves(object):
            def names(self):
                return test_class_self.mock_wave_names()
        waves = PyWaves()
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(waves)
        identifier = WaveformIdentifier(instance_name='inst1', path='/foo1/bar1', term_name="d")
        expected = WaveformNameInfo(identifier=identifier, type=identifier.type, full_name='top.foo1@bar1@net2#foo1@bar1@inst1_d', path=identifier.path, terminal_name=identifier.term_name)
        assert adapter.info_from_identifier(identifier) == expected

Solution

  • Define PyWaves at the global scope, and inject the mock function (or the object, but since in this case you just want one method, why complicate it?):

    class PyWaves:
        def __init__(self, mock_wave_names_fn):
            self.mock_wave_names = mock_wave_names_fn
    
        def names(self):
            return self.mock_wave_names()
    

    and then you can replace this:

        def test_build_index(self, mocker):
            test_class_self = self
            class PyWaves(object):
                def names(self):
                    return test_class_self.mock_wave_names()
            waves = PyWaves()
            adapter = ExtractedXPSNameAdapter()
            adapter.set_wave_list(waves)
            adapter.build_index()
            assert adapter.index == self.expected_index()
    

    with a DRYer version:

        def test_build_index(self, mocker):
            adapter = ExtractedXPSNameAdapter()
            adapter.set_wave_list(PyWaves(self.mock_wave_names))
            adapter.build_index()
            assert adapter.index == self.expected_index()