python language-lawyer operators method-resolution-order

Python MRO for operators: Chooses RHS `__rmul__` instead of LHS `__mul__` when RHS is a subclass


Consider the following self-contained example:

class Matrix:
    def __mul__(self, other):
        print("Matrix.__mul__(⋯)")
        return NotImplemented
    def __rmul__(self, other):
        print("Matrix.__rmul__(⋯)")
        return NotImplemented

class Vector(Matrix):
    def __mul__(self, other):
        print("Vector.__mul__(⋯)")
        return NotImplemented
    def __rmul__(self, other):
        print("Vector.__rmul__(⋯)")
        return NotImplemented

matr = Matrix()
vec  = Vector()

print("=== Using explicit `__mul__`: ===")
matr.__mul__(vec)
print()

print("=== Using implicit `*`: ===")
matr * vec

with output (CPython 3.12.2):

=== Using explicit `__mul__`: ===
Matrix.__mul__(⋯)

=== Using implicit `*`: ===
Vector.__rmul__(⋯)
Matrix.__mul__(⋯)
(TypeError raised: "unsupported operand type(s) for *: 'Matrix' and 'Vector'")

I'm trying to understand why, in the second case, Vector.__rmul__ gets called before Matrix.__mul__.

My understanding is that Python sees the *, then looks at the LHS and RHS. If it sees the LHS has a __mul__, it's called (with arguments self=LHS, right=RHS). If that returns NotImplemented, only then does it try the RHS's __rmul__ (with arguments self=RHS, right=LHS).

In particular, in this case, when it looks at the LHS, it should find Matrix.__mul__, and only when that fails should it then try Vector.__rmul__. But it's doing it the other way around! Why?

It's also important to note: this only happens when Vector is a subclass of Matrix. If they are unrelated, then the result is as expected.


Solution

  • This happens precisely because of inheritance. The order of operations is different when the right operand's type is a subclass of the left operand's type and that subclass provides its own implementation of the reflected method (here __rmul__()): Python then tries the reflected method first, which lets the subclass override the behavior of the superclass. Let's see this in action. A few examples of the normal order of operations come first, and the subclass example comes last:

    Example 1: Normal order of operations #1

    class MyInt1():
        def __init__(self, val=0):
            self.val = val
    
        def __int__(self):
            return self.val
    
        def __mul__(self, other):
            print('in MyInt1.__mul__()')
            return self.val * other.val
    
    
    class MyInt2():
        def __init__(self, val=0):
            self.val = val
    
        def __rmul__(self, other):
            print('in MyInt2.__rmul__()')
            return other * self.val
    
    
    if __name__ == '__main__':
        a = MyInt1(1)
        b = MyInt2(3)
        c = a * b
    

    Outputs:

    in MyInt1.__mul__()
    

    In the above code, the class MyInt1 defines the __mul__() method, which returns the product of self.val and other.val. Note that MyInt2's __rmul__() was never called, because MyInt1.__mul__() had already returned a result, so the multiplication was resolved before reaching that step. Here's another example:

    class MyInt1():
        def __init__(self, val=0):
            self.val = val
    
        def __int__(self):
            return self.val
    
        def __mul__(self, other):
            print('in MyInt1.__mul__()')
            return self.val * other
    
    
    class MyInt2():
        def __init__(self, val=0):
            self.val = val
    
        def __rmul__(self, other):
            print('in MyInt2.__rmul__()')
            return other * self.val
    
    
    if __name__ == '__main__':
        a = MyInt1(1)
        b = MyInt2(3)
        c = a * b
    

    Outputs:

    in MyInt1.__mul__()
    in MyInt2.__rmul__()
    

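    This code, although extremely similar to the above, is actually quite different. (Notice the self.val * other instead of self.val * other.val in MyInt1.__mul__(). This makes all the difference.) This code goes through three steps:

    1. a * b calls MyInt1.__mul__(a, b), which prints its message and evaluates self.val * other, that is, 1 * b.
    2. For 1 * b, Python first calls int.__mul__(1, b), which returns NotImplemented because int does not know how to multiply by a MyInt2.
    3. Python then falls back to the reflected method MyInt2.__rmul__(b, 1), which prints its message and returns 1 * 3 = 3.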
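    The fallback that happens inside MyInt1.__mul__() can be checked directly: int's own __mul__() does not know MyInt2 and returns NotImplemented, so Python moves on to MyInt2.__rmul__(). The sketch below repeats MyInt2 so that it runs on its own; the names direct and via_operator are introduced here purely for illustration:

```python
class MyInt2:
    def __init__(self, val=0):
        self.val = val

    def __rmul__(self, other):
        # Reached only after the left operand's __mul__ gave up.
        return other * self.val


b = MyInt2(3)

# Calling the method by hand shows int declining the operation.
direct = (1).__mul__(b)
print(direct)        # NotImplemented

# The operator sees NotImplemented and falls back to b.__rmul__(1).
via_operator = 1 * b
print(via_operator)  # 3
```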

    Now, let's look at another example that, although it doesn't do anything useful, shows the concept more clearly:

    class MyInt1():
        def __mul__(self, other):
            print('in MyInt1.__mul__()')
            return 2 * 2
    
        def __rmul__(self, other):
            print('in MyInt1.__rmul__()')
            return 4 * 4
    
    
    class MyInt2():
        def __mul__(self, other):
            print('in MyInt2.__mul__()')
            return 3 * 3
    
        def __rmul__(self, other):
            print('in MyInt2.__rmul__()')
            return 6 * 6
    
    
    if __name__ == '__main__':
        a = MyInt1()
        b = MyInt2()
        c = a * b
        print(c)
    

    Outputs:

    in MyInt1.__mul__()
    4
    

    This code is another example of the normal order of operations: MyInt1.__mul__() returns a result, so MyInt2.__rmul__() is never tried. Now let's see the same code with inheritance involved:

    class MyInt1():
        def __mul__(self, other):
            print('in MyInt1.__mul__()')
            return 2 * 2
    
        def __rmul__(self, other):
            print('in MyInt1.__rmul__()')
            return 4 * 4
    
    
    class MyInt2(MyInt1):
        def __mul__(self, other):
            print('in MyInt2.__mul__()')
            return 3 * 3
    
        def __rmul__(self, other):
            print('in MyInt2.__rmul__()')
            return 6 * 6
    
    
    if __name__ == '__main__':
        a = MyInt1()
        b = MyInt2()
        c = a * b
        d = b * a
        print(c)
        print(d)
    

    Outputs:

    in MyInt2.__rmul__()
    36
    in MyInt2.__mul__()
    9
    
    

    As you can see, once MyInt2 is a subclass of MyInt1, it no longer matters whether the subclass instance is the left or the right operand: in both cases the subclass's method is called first. When the subclass is the left operand, nothing changes from the normal order, since the left operand's __mul__() is tried first anyway. When the subclass is the right operand, its __rmul__() is tried before the left operand's __mul__(). This applies only when the subclass provides its own implementation of the reflected method; if it merely inherits __rmul__(), the normal order is used. The point of this rule is to let subclasses override their superclasses' behavior.
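    To see that the priority depends on the subclass actually overriding the reflected method, here is a small sketch that records the call order for two subclasses. The names Base, Inherits, Overrides and trace are hypothetical and exist only for this illustration; every method returns NotImplemented, as in the question, so that all attempts become visible:

```python
calls = []  # records which methods ran, in order


class Base:
    def __mul__(self, other):
        calls.append("Base.__mul__")
        return NotImplemented

    def __rmul__(self, other):
        calls.append("Base.__rmul__")
        return NotImplemented


class Inherits(Base):
    """Subclass that keeps Base's __rmul__ unchanged."""


class Overrides(Base):
    """Subclass with its own __rmul__."""

    def __rmul__(self, other):
        calls.append("Overrides.__rmul__")
        return NotImplemented


def trace(lhs, rhs):
    """Return the methods Python tried for lhs * rhs, in order."""
    calls.clear()
    try:
        lhs * rhs
    except TypeError:
        pass  # every method returned NotImplemented
    return list(calls)


print(trace(Base(), Inherits()))   # ['Base.__mul__', 'Base.__rmul__']
print(trace(Base(), Overrides()))  # ['Overrides.__rmul__', 'Base.__mul__']
```

    Only the Overrides case reverses the order, which matches the Matrix and Vector output in the question.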

    Reference: See __add__ and __radd__ special order of preference for more.
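    As a rough summary of the rule, the selection can be modelled in pure Python. This is a simplified sketch, not CPython's actual implementation: emulate_mul is a hypothetical helper, and the Matrix and Vector classes are reduced variants of the ones in the question that return strings so the chosen method is visible.

```python
def emulate_mul(lhs, rhs):
    """Simplified model of how `lhs * rhs` chooses between __mul__ and __rmul__."""
    lhs_type, rhs_type = type(lhs), type(rhs)
    # Special methods are looked up on the type, not on the instance.
    lhs_mul = getattr(lhs_type, "__mul__", None)
    lhs_rmul = getattr(lhs_type, "__rmul__", None)
    rhs_rmul = getattr(rhs_type, "__rmul__", None)

    different_types = rhs_type is not lhs_type
    # The reflected method goes first only for a proper subclass on the
    # right that provides its own __rmul__.
    rhs_first = (
        different_types
        and issubclass(rhs_type, lhs_type)
        and rhs_rmul is not None
        and rhs_rmul is not lhs_rmul
    )

    if rhs_first:
        attempts = [(rhs_rmul, rhs, lhs), (lhs_mul, lhs, rhs)]
    else:
        attempts = [(lhs_mul, lhs, rhs)]
        if different_types:  # same-type operands never try __rmul__
            attempts.append((rhs_rmul, rhs, lhs))

    for method, self_arg, other_arg in attempts:
        if method is None:
            continue
        result = method(self_arg, other_arg)
        if result is not NotImplemented:
            return result
    raise TypeError(
        f"unsupported operand type(s) for *: "
        f"{lhs_type.__name__!r} and {rhs_type.__name__!r}"
    )


class Matrix:
    def __mul__(self, other):
        return "Matrix.__mul__"

    def __rmul__(self, other):
        return "Matrix.__rmul__"


class Vector(Matrix):
    def __rmul__(self, other):
        return "Vector.__rmul__"


print(emulate_mul(Matrix(), Vector()))  # Vector.__rmul__
print(emulate_mul(Vector(), Matrix()))  # Matrix.__mul__
print(emulate_mul(2, 3.0))              # 6.0
```

    The model ignores details such as the sequence-repetition fallback for built-in sequences, so treat it as an aid for reasoning about the order rather than a replacement for the operator.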