Tags: python, python-3.x, multiple-inheritance, diamond-problem

Best way of solving diamond problem in Python with fields


Python handles the diamond problem well when the classes have no fields, by linearizing the method resolution order (MRO). However, if the classes have fields, how do you call the super constructors? Consider:

class A:
    def __init__(self, a):
        self.a = a  # Should only be initialized once.

class B(A):
    def __init__(self, a, b):
        super().__init__(a)
        self.b = b

class C(A):
    def __init__(self, a, c, b=None):
        super().__init__(a)
        self.c = c

class D(C, B):
    def __init__(self, a, b, c):
        super().???  # What do you put in here.

For my use case I do have a workaround: b can never be None in the application, so the following largely works:

class A:
    def __init__(self, a):
        self.a = a  # Should only be initialized once.

class B(A):
    def __init__(self, a, b):
        assert b is not None  # Special case of `b` can't be `None`.
        super().__init__(a)
        self.b = b

class C(A):
    def __init__(self, a, c, b=None):  # Special init with default sentinel `b`.
        if b is None:
            super().__init__(a)  # Normally `C`'s super is `A`.
        else:
            super().__init__(a, b)  # From `D` though, `C`'s super is `B`.
        self.c = c

class D(C, B):  # Note order, `C`'s init is super init.
    def __init__(self, a, b, c):
        super().__init__(a, c, b)

def main():
    A('a')
    B('b', 1)
    C('c', 2)
    D('d', 3, 4)
    C('c2', 5, 6)  # TypeError: __init__() takes 2 positional arguments but 3 were given

This largely works for the special case where b can't be None. However, it still fails if C's __init__ is called directly with a b argument (see the last line above). Also, you have to modify C to support the multiple inheritance, and you have to inherit in the order C, B.
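The reason C needs two code paths can be seen by inspecting the MRO. The following is a minimal sketch, using stripped-down stand-ins for the classes above (no fields) and a hypothetical helper, next_after, that is not part of the original code. It only reports which class super() inside a given class would resolve to for a given instance type.

```python
class A:
    pass

class B(A):
    pass

class C(A):
    pass

class D(C, B):
    pass

def next_after(cls, instance_type):
    """Hypothetical helper: the class that `super()` inside `cls` resolves to
    when the instance is of type `instance_type`."""
    mro = instance_type.__mro__
    return mro[mro.index(cls) + 1]

print([k.__name__ for k in D.__mro__])  # ['D', 'C', 'B', 'A', 'object']
print(next_after(C, C).__name__)  # A: a plain `C` instance skips `B` entirely.
print(next_after(C, D).__name__)  # B: inside a `D` instance, `C`'s super is `B`.
```

So the same super().__init__ call in C reaches A.__init__ or B.__init__ depending on the type of the instance, which is why a fixed positional signature cannot serve both cases.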

==== Edit ===

Another possibility is to manually initialize each field (this is somewhat similar to how Scala handles fields under the covers).

class A0:
    def __init__(self, a):  # Special separate init of `a`.
        self._init_a(a)

    def _init_a(self, a):
        self.a = a


class B0(A0):
    def __init__(self, a, b):  # Special separate init of `b`.
        self._init_a(a)
        self._init_b(b)

    def _init_b(self, b):
        self.b = b


class C0(A0):
    def __init__(self, a, c):  # Special separate init of `c`.
        self._init_a(a)
        self._init_c(c)

    def _init_c(self, c):
        self.c = c

class D0(C0, B0):
    def __init__(self, a, b, c):  # Uses special separate inits of `a`, `b`, and `c`.
        self._init_a(a)
        self._init_b(b)
        self._init_c(c)

The disadvantage of this approach is that it is very non-standard, to the extent that PyCharm gives a warning about not calling super init.

==== End edit ===

Is there a better way?

Thanks in advance for any help, Howard.


Solution

  • There is an excellent article by @rhettinger, https://rhettinger.wordpress.com/2011/05/26/super-considered-super/, that has the answer. The technique is subtle: each class's __init__ picks out, by name, the keyword arguments relevant to that class and passes the rest on to the next class in the MRO via **kwargs. See the article for a full explanation.

    Therefore the 'pythonic' way to code the diamond problem with fields is:

    class A1:
        def __init__(self, *, a):  # Will fail if anything other than arg `a` is passed on to it.
            self.a = a
    
        def __repr__(self):
            return f'{type(self).__name__}(**{self.__dict__})'
    
    
    class B1(A1):
        def __init__(self, *, b, **kwargs):  # Extracts arg `b` and passes others on to other supers.
            self.b = b
            super().__init__(**kwargs)
    
    
    class C1(A1):
        def __init__(self, *, c, **kwargs):  # Extracts arg `c` and passes others on to other supers.
            self.c = c
            super().__init__(**kwargs)
    
    
    class D1(C1, B1):  # Note order, C1's init is 1st super init.
        def __init__(self, **kwargs):  # Passes on to supers all arguments.
            super().__init__(**kwargs)
    
    
    def main1():
        print(A1(a='a'))
        print(B1(a='b', b=1))
        print(C1(a='c', c=2))
        print(D1(a='d', b=3, c=4))
    
    
    if __name__ == '__main__':
        main1()
    

    The above prints:

    A1(**{'a': 'a'})
    B1(**{'b': 1, 'a': 'b'})
    C1(**{'c': 2, 'a': 'c'})
    D1(**{'c': 4, 'b': 3, 'a': 'd'})
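To see the cooperative chain at work, here is a minimal sketch that mirrors the classes above under new, hypothetical names (A2 to D2) and adds a calls list, which is not in the original, to record the order in which each __init__ runs. It also shows, as an illustration only, that a stray keyword argument is rejected once it reaches the root class.

```python
calls = []  # Hypothetical log of the order in which each __init__ runs.

class A2:
    def __init__(self, *, a):  # Root of the chain: accepts only `a`.
        calls.append('A2')
        self.a = a

class B2(A2):
    def __init__(self, *, b, **kwargs):  # Takes `b`, forwards the rest.
        calls.append('B2')
        self.b = b
        super().__init__(**kwargs)

class C2(A2):
    def __init__(self, *, c, **kwargs):  # Takes `c`, forwards the rest.
        calls.append('C2')
        self.c = c
        super().__init__(**kwargs)

class D2(C2, B2):
    def __init__(self, **kwargs):  # Forwards everything along the MRO.
        calls.append('D2')
        super().__init__(**kwargs)

d = D2(a='d', b=3, c=4)
print(calls)  # ['D2', 'C2', 'B2', 'A2']: each __init__ runs exactly once.

try:
    D2(a='d', b=3, c=4, e=5)  # `e` is consumed by no class in the chain.
except TypeError as exc:
    print('rejected:', type(exc).__name__)  # rejected: TypeError
```

Because every argument is keyword-only, the base order no longer affects how arguments are matched; it only changes the order in which the __init__ methods run.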