
Elegant ways to support equivalence ("equality") in Python classes


When writing custom classes, it is often important to support equivalence through the == and != operators. In Python, this is done by implementing the __eq__ and __ne__ special methods, respectively. The easiest way I've found to do this is the following:

class Foo:
    def __init__(self, item):
        self.item = item

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        else:
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

Do you know of more elegant means of doing this? Do you know of any particular disadvantages to using the above method of comparing __dict__s?

Note: A bit of clarification--when __eq__ and __ne__ are undefined, you'll find this behavior:

>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
False

That is, a == b evaluates to False because it really runs a is b, a test of identity (i.e., "Is a the same object as b?").

When __eq__ and __ne__ are defined, you'll find this behavior (which is the one we're after):

>>> a = Foo(1)
>>> b = Foo(1)
>>> a is b
False
>>> a == b
True

Solution

  • Consider this simple problem:

    class Number:
    
        def __init__(self, number):
            self.number = number
    
    
    n1 = Number(1)
    n2 = Number(1)
    
    n1 == n2 # False -- oops
    

    By default, Python compares objects by identity, using their object identifiers:

    id(n1) # 140400634555856
    id(n2) # 140400634555920
    

    Overriding the __eq__ function seems to solve the problem:

    def __eq__(self, other):
        """Overrides the default implementation"""
        if isinstance(other, Number):
            return self.number == other.number
        return False
    
    
    n1 == n2 # True
    n1 != n2 # True in Python 2 -- oops, False in Python 3
    

    In Python 2, always remember to override the __ne__ function as well, as the documentation states:

    There are no implied relationships among the comparison operators. The truth of x==y does not imply that x!=y is false. Accordingly, when defining __eq__(), one should also define __ne__() so that the operators will behave as expected.

    def __ne__(self, other):
        """Overrides the default implementation (unnecessary in Python 3)"""
        return not self.__eq__(other)
    
    
    n1 == n2 # True
    n1 != n2 # False
    

    In Python 3, this is no longer necessary, as the documentation states:

    By default, __ne__() delegates to __eq__() and inverts the result unless it is NotImplemented. There are no other implied relationships among the comparison operators, for example, the truth of (x<y or x==y) does not imply x<=y.
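
    As a quick check of the Python 3 behavior quoted above, here is a minimal sketch (Python 3 only; the Point class is a hypothetical stand-in, not part of the example above) that defines __eq__ alone and relies on the default __ne__:

```python
class Point:
    """Hypothetical example class that defines only __eq__ (Python 3)."""

    def __init__(self, x):
        self.x = x

    def __eq__(self, other):
        if isinstance(other, Point):
            return self.x == other.x
        return NotImplemented


p = Point(1)
q = Point(1)
r = Point(2)

same_equal = p == q      # uses Point.__eq__
same_unequal = p != q    # default __ne__ inverts the result of __eq__
diff_unequal = p != r    # default __ne__ inverts the result of __eq__
```

    Here p != q evaluates to False and p != r to True without any explicit __ne__.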

    But that does not solve all our problems. Let’s add a subclass:

    class SubNumber(Number):
        pass
    
    
    n3 = SubNumber(1)
    
    n1 == n3 # False for classic-style classes -- oops, True for new-style classes
    n3 == n1 # True
    n1 != n3 # True for classic-style classes -- oops, False for new-style classes
    n3 != n1 # False
    

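    Note: Python 2 has two kinds of classes: classic-style (old-style) classes, which do not inherit from object, and new-style classes, which do. Python 3 has only new-style classes.

    For classic-style classes, a comparison operation always calls the method of the first operand, while for new-style classes, it always calls the method of the subclass operand, regardless of the order of the operands.

    So here, if Number is a classic-style class, n1 == n3 calls n1.__eq__, n3 == n1 calls n3.__eq__, n1 != n3 calls n1.__ne__, and n3 != n1 calls n3.__ne__.

    And if Number is a new-style class, both n1 == n3 and n3 == n1 call n3.__eq__, and both n1 != n3 and n3 != n1 call n3.__ne__.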
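
    The new-style dispatch rule can be observed directly in Python 3, where every class is new-style. The sketch below uses hypothetical Base and Sub classes and a calls list (none of them from the example above) to record which __eq__ runs first:

```python
calls = []  # records the order in which the __eq__ methods are invoked


class Base:
    def __eq__(self, other):
        calls.append("Base.__eq__")
        return NotImplemented


class Sub(Base):
    def __eq__(self, other):
        calls.append("Sub.__eq__")
        return isinstance(other, Base)


b = Base()
s = Sub()

b == s  # subclass operand on the right
right_operand_calls = list(calls)

calls.clear()
s == b  # subclass operand on the left
left_operand_calls = list(calls)
```

    In both orders only Sub.__eq__ is recorded, because the subclass operand gets the first chance to answer.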

    To fix the non-commutativity issue of the == and != operators for Python 2 classic-style classes, the __eq__ and __ne__ methods should return the NotImplemented value when an operand type is not supported. The documentation defines the NotImplemented value as:

    Numeric methods and rich comparison methods may return this value if they do not implement the operation for the operands provided. (The interpreter will then try the reflected operation, or some other fallback, depending on the operator.) Its truth value is true.

    In this case the operator delegates the comparison operation to the reflected method of the other operand. The documentation defines reflected methods as:

    There are no swapped-argument versions of these methods (to be used when the left argument does not support the operation but the right argument does); rather, __lt__() and __gt__() are each other’s reflection, __le__() and __ge__() are each other’s reflection, and __eq__() and __ne__() are their own reflection.

    The result looks like this:

    def __eq__(self, other):
        """Overrides the default implementation"""
        if isinstance(other, Number):
            return self.number == other.number
        return NotImplemented
    
    def __ne__(self, other):
        """Overrides the default implementation (unnecessary in Python 3)"""
        x = self.__eq__(other)
        if x is NotImplemented:
            return NotImplemented
        return not x
    

    Returning the NotImplemented value instead of False is the right thing to do even for new-style classes if commutativity of the == and != operators is desired when the operands are of unrelated types (no inheritance).
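
    To illustrate the unrelated-types case, here is a minimal Python 3 sketch. StrictNumber, PoliteNumber and Wrapper are hypothetical classes introduced only for this comparison; Wrapper is assumed to know how to compare itself with the other two:

```python
class StrictNumber:
    """Returns False for operand types it does not know."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, StrictNumber):
            return self.number == other.number
        return False


class PoliteNumber:
    """Returns NotImplemented for operand types it does not know."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, PoliteNumber):
            return self.number == other.number
        return NotImplemented


class Wrapper:
    """Unrelated type (no inheritance) that supports both classes above."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, (StrictNumber, PoliteNumber, Wrapper)):
            return self.number == other.number
        return NotImplemented


w = Wrapper(1)

strict_left = StrictNumber(1) == w   # StrictNumber.__eq__ answers False itself
strict_right = w == StrictNumber(1)  # Wrapper.__eq__ answers True
polite_left = PoliteNumber(1) == w   # NotImplemented, so Wrapper.__eq__ is tried
polite_right = w == PoliteNumber(1)  # Wrapper.__eq__ answers True
```

    With StrictNumber the result depends on operand order; with PoliteNumber both orders agree, because the interpreter falls back to the reflected __eq__ of the other operand.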

    Are we there yet? Not quite. How many unique numbers do we have?

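    len(set([n1, n2, n3])) # 3 for Python 2 new-style classes -- oops, TypeError (unhashable type) in Python 3
    

    Sets use the hashes of objects, and by default Python derives the hash from the identifier of the object. In Python 3, a class that overrides __eq__ without defining __hash__ has its __hash__ set to None, so its instances cannot be put in a set at all. Let’s try to override it: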

    def __hash__(self):
        """Overrides the default implementation"""
        return hash(tuple(sorted(self.__dict__.items())))
    
    len(set([n1, n2, n3])) # 1
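
    The rule behind this override is that objects which compare equal must also hash equally. A minimal Python 3 sketch, using hypothetical EqOnly and EqAndHash classes (not from the example above) and hashing only the attribute that __eq__ compares:

```python
class EqOnly:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, EqOnly):
            return self.number == other.number
        return NotImplemented


class EqAndHash:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, EqAndHash):
            return self.number == other.number
        return NotImplemented

    def __hash__(self):
        # Hash exactly what __eq__ compares, so equal objects hash equally.
        return hash(self.number)


try:
    set([EqOnly(1), EqOnly(1)])
    eq_only_hashable = True
except TypeError:
    # In Python 3, defining __eq__ without __hash__ sets __hash__ to None.
    eq_only_hashable = False

unique_count = len(set([EqAndHash(1), EqAndHash(1), EqAndHash(2)]))
```

    Because the hash depends on a mutable attribute, changing number while an instance is stored in a set or used as a dict key would make it unfindable, so such objects are best treated as immutable.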
    

    The end result looks like this (I added some assertions at the end for validation):

    class Number:
    
        def __init__(self, number):
            self.number = number
    
        def __eq__(self, other):
            """Overrides the default implementation"""
            if isinstance(other, Number):
                return self.number == other.number
            return NotImplemented
    
        def __ne__(self, other):
            """Overrides the default implementation (unnecessary in Python 3)"""
            x = self.__eq__(other)
            if x is not NotImplemented:
                return not x
            return NotImplemented
    
        def __hash__(self):
            """Overrides the default implementation"""
            return hash(tuple(sorted(self.__dict__.items())))
    
    
    class SubNumber(Number):
        pass
    
    
    n1 = Number(1)
    n2 = Number(1)
    n3 = SubNumber(1)
    n4 = SubNumber(4)
    
    assert n1 == n2
    assert n2 == n1
    assert not n1 != n2
    assert not n2 != n1
    
    assert n1 == n3
    assert n3 == n1
    assert not n1 != n3
    assert not n3 != n1
    
    assert not n1 == n4
    assert not n4 == n1
    assert n1 != n4
    assert n4 != n1
    
    assert len(set([n1, n2, n3, ])) == 1
    assert len(set([n1, n2, n3, n4])) == 2