
Automatically add decorator to all inherited methods


In class B, I want to automatically add the decorator _preCheck to all methods inherited from class A. In the example, b.double(5) is correctly called through the wrapper. Rather than manually re-declaring (overriding) the inherited methods in B, I want them decorated automatically, so that the call to b.add(1,2) also goes through the _preCheck wrapper.

    from functools import wraps

    class A(object):
        def __init__(self, name):
            self.name = name
        
        def add(self, a, b):
            return a + b

    class B(A):
        def __init__(self, name, foo):
            super().__init__(name)
            self.foo = foo
            
        def _preCheck(func):
            @wraps(func)
            def wrapper(self, *args, **kwargs):
                print("preProcess", self.name)
                return func(self, *args, **kwargs)
            return wrapper
                    
        @_preCheck
        def double(self, i):
            return i * 2
        
    b = B('myInst', 'bar')
    print(b.double(5))
    print(b.add(1,2))

Based on "How can I decorate all inherited methods in a subclass", I thought a possible solution might be to add the following snippet to B's __init__ method:

        for attr_name in A.__dict__:
            attr = getattr(self, attr_name)
            if callable(attr):
                setattr(self, attr_name, self._preCheck(attr))

However, I get the following error. I suspect the second argument comes from 'self'.

TypeError: _preCheck() takes 1 positional argument but 2 were given
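The suspicion is correct. Looked up through the instance, _preCheck becomes a bound method, so self._preCheck(attr) actually calls _preCheck(self, attr), which is one argument too many. The sketch below reproduces that error and shows one way to make the per-instance idea work: it fetches _preCheck through the class (a plain function, so nothing is bound), wraps the raw functions from vars(A) rather than the bound methods that getattr(self, ...) returns, and binds each wrapper to the instance with types.MethodType. This is an illustration of the mechanism under those assumptions, not the approach taken in the answer.

```python
from functools import wraps
import types


class A(object):
    def __init__(self, name):
        self.name = name

    def add(self, a, b):
        return a + b


class B(A):
    def __init__(self, name, foo):
        super().__init__(name)
        self.foo = foo
        for attr_name, attr in vars(A).items():
            if attr_name.startswith('__') or not callable(attr):
                continue  # skip magic methods and non-callables
            # Fetch _preCheck through the class: a plain function, no binding.
            wrapped = B._preCheck(attr)
            # Bind the wrapper to this instance so `self` is passed on each call.
            setattr(self, attr_name, types.MethodType(wrapped, self))

    def _preCheck(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            print("preProcess", self.name)
            return func(self, *args, **kwargs)
        return wrapper

    @_preCheck
    def double(self, i):
        return i * 2


b = B('myInst', 'bar')

# The original error: through the instance, _preCheck is a bound method,
# so b._preCheck(A.add) calls _preCheck(b, A.add) with two positional args.
try:
    b._preCheck(A.add)
except TypeError as exc:
    print("TypeError:", exc)

print(b.add(1, 2))           # prints "preProcess myInst", then 3
print(A('plain').add(1, 2))  # prints 3 only: A is not modified
```

This works, but it re-wraps every inherited method on each instantiation, and the instance attribute shadows any override of add that a further subclass of B might define, which is one reason to patch the class instead.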

Solutions to similar problems exist, but they initialize the subclasses from within the base class: "Add decorator to a method from inherited class?" and "Apply a python decorator to all inheriting classes".


Solution

  • Decorators need to be added to the class itself, not to the instance:

    from functools import wraps
    
    class A(object):
        def __init__(self, name):
            self.name = name
        
        def add(self, a, b):
            return a + b
    
    class B(A):
        def __init__(self, name, foo):
            super().__init__(name)
            self.foo = foo
            
        def _preCheck(func):
            @wraps(func)
            def wrapper(self, *args, **kwargs):
                print("preProcess", self.name)
                return func(self, *args, **kwargs)
            return wrapper
                    
        @_preCheck
        def double(self, i):
            return i * 2
    
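    # Patch B, not A: assigning on B creates overrides in B only,
    # so plain A instances keep their undecorated methods.
    for attr_name in A.__dict__:
        if attr_name.startswith('__'):  # skip magic methods
            continue
        print(f"Decorating: {attr_name}")
        attr = getattr(A, attr_name)
        if callable(attr):
            setattr(B, attr_name, B._preCheck(attr))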
        
    b = B('myInst', 'bar')
    print(b.double(5))
    print(b.add(1,2))
    

    Out:

    Decorating: add
    preProcess myInst
    10
    preProcess myInst
    3
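If the same treatment is needed for more than one subclass, the loop can be packaged as a class decorator. The sketch below is one way to do that, not the only one: it assumes _preCheck is moved to module level (renamed pre_check here), and decorate_inherited is a hypothetical helper name. It wraps only the plain functions that B inherits and does not define itself, and it leaves A untouched.

```python
from functools import wraps
import inspect


def pre_check(func):
    """Print a pre-processing message, then call the wrapped method."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        print("preProcess", self.name)
        return func(self, *args, **kwargs)
    return wrapper


def decorate_inherited(decorator):
    """Hypothetical class decorator: wrap every public plain function that
    `cls` inherits (but does not define itself) with `decorator`."""
    def apply(cls):
        for base in cls.__mro__[1:]:  # nearest base first; skip cls itself
            if base is object:
                continue
            for name, attr in vars(base).items():
                if name.startswith('__') or name in vars(cls):
                    continue  # magic method, or already defined/wrapped on cls
                if inspect.isfunction(attr):  # skips staticmethod/classmethod/property
                    setattr(cls, name, decorator(attr))
        return cls
    return apply


class A(object):
    def __init__(self, name):
        self.name = name

    def add(self, a, b):
        return a + b


@decorate_inherited(pre_check)
class B(A):
    def __init__(self, name, foo):
        super().__init__(name)
        self.foo = foo

    @pre_check
    def double(self, i):
        return i * 2


b = B('myInst', 'bar')
print(b.double(5))           # prints "preProcess myInst", then 10
print(b.add(1, 2))           # prints "preProcess myInst", then 3
print(A('plain').add(1, 2))  # prints 3 only: A is not modified
```

Walking cls.__mro__ nearest-first and skipping names already present in vars(cls) means that, in the usual single-inheritance case, each inherited method is wrapped once, using the version normal attribute lookup would find. inspect.isfunction deliberately skips staticmethod, classmethod, and property objects, whose calling conventions the self-based wrapper would not match.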