python, oop, inheritance

Define multiple comparison methods at once using the same principle


I am building a binary tree in Python, defining the node as a class. I want each node to have a value and to be comparable to other nodes so that I can, for example, sort a list of them.

I wanted to know whether there is a more elegant way than explicitly defining every comparison method (__le__, __lt__, __eq__, ...). I tried this and it works:

class Node:
  def __init__(self, value, name=None, left=None, right=None):
    self.value = value
    self.name = name
    self.right, self.left = right, left
  
  def is_leaf(self):
    return self.right is None and self.left is None
  
  def __le__(self, other):
    return self.value.__le__(other.value) # or (self.value <= other.value)
  # same for __lt__, __eq__

But I would like to avoid repeating that code. More generally, I want the object to delegate a list of dunder methods to self.value without coding each one explicitly.
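One way to get that kind of delegation is to generate the methods in a loop from a class decorator. This is a minimal sketch, assuming nodes are only compared with other nodes; `delegate_to` and `_COMPARISONS` are hypothetical names, not an existing library API, and the naming relies on each `operator` function's `__name__` (for example `operator.lt.__name__ == "lt"`) matching the dunder it implements:

```python
import operator

# Comparison operators to delegate (hypothetical default list).
_COMPARISONS = (operator.lt, operator.le, operator.eq,
                operator.ne, operator.gt, operator.ge)


def delegate_to(attr, ops=_COMPARISONS):
    """Hypothetical class decorator: implement each dunder in `ops` by
    applying the operator to getattr(self, attr) and getattr(other, attr)."""
    def decorator(cls):
        def make_method(op):
            def method(self, other):
                if not isinstance(other, cls):
                    return NotImplemented  # let Python try the reflected op
                return op(getattr(self, attr), getattr(other, attr))
            method.__name__ = f"__{op.__name__}__"
            return method

        for op in ops:
            setattr(cls, f"__{op.__name__}__", make_method(op))
        if operator.eq in ops:
            # Writing __eq__ in a class body sets __hash__ to None; do the
            # same here so nodes that compare equal never hash differently.
            cls.__hash__ = None
        return cls
    return decorator


@delegate_to("value")
class Node:
    def __init__(self, value, name=None, left=None, right=None):
        self.value = value
        self.name = name
        self.right, self.left = right, left

    def is_leaf(self):
        return self.right is None and self.left is None


nodes = [Node(3, "c"), Node(1, "a"), Node(2, "b")]
print([n.name for n in sorted(nodes)])        # ['a', 'b', 'c']
print(Node(1) == Node(1), Node(1) < Node(2))  # True True
```

Returning `NotImplemented` for non-`Node` operands means `Node(1) < 1` raises `TypeError`, while `Node(1) == 1` falls back to identity and is simply `False`.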

I considered making Node inherit from the type of its values, e.g. class Node(float) if the values are numbers, or doing this in the initializer:

  def __init__(self, value, name=None, left=None, right=None):
    type(value).__init__(value)
    # etc.

But both seem like bad practice to me, since they add a lot of potentially unexpected behaviour.
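To make that concern concrete, here is a minimal sketch of the `class Node(float)` option (the name `FloatNode` is hypothetical). Because `float` is immutable, the number has to go through `float.__new__`; calling `type(value).__init__(value)` only runs an initializer on the value object and does not turn the node into a float. The last lines show a few behaviours a tree class probably does not want:

```python
class FloatNode(float):
    """Hypothetical node that *is* a float, as in `class Node(float)`."""

    # float is immutable, so the numeric value must be set in __new__.
    def __new__(cls, value, name=None, left=None, right=None):
        return super().__new__(cls, value)

    def __init__(self, value, name=None, left=None, right=None):
        self.name = name
        self.right, self.left = right, left

    def is_leaf(self):
        return self.right is None and self.left is None


zero = FloatNode(0.0, name="root")
print(bool(zero))                    # False: a node with value 0 is falsy
total = FloatNode(1.5) + FloatNode(2.0)
print(type(total).__name__)          # float: arithmetic drops the node
print(FloatNode(3.0) == 3.0)         # True: nodes equal bare numbers
```

The first one is the most dangerous for tree code: a check like `if node.left:` would treat a child whose value is 0 as missing.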

Is there a Pythonic, elegant way to avoid explicitly defining every comparison method in the class when they all follow the same pattern?


Solution

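  • IIUC, you can use functools.total_ordering. Define __eq__ and one of __lt__, __le__, __gt__ or __ge__, and the decorator supplies the remaining ordering methods: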

    from functools import total_ordering
    
    
    @total_ordering
    class Node:
        def __init__(self, value, name=None, left=None, right=None):
            self.value = value
            self.name = name
            self.right, self.left = right, left
    
        def is_leaf(self):
            return self.right is None and self.left is None
    
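        def __le__(self, other):
            # Accept another Node or a bare value such as 5
            other_value = other.value if isinstance(other, Node) else other
            return self.value <= other_value
    
        def __eq__(self, other):
            other_value = other.value if isinstance(other, Node) else other
            return self.value == other_value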
    
    
    n1 = Node(10)
    n2 = Node(10)
    n3 = Node(20)
    
    print(f"{n1 < n2 = }")
    print(f"{n1 > n2 = }")
    print(f"{n1 == n2 = }")
    print(f"{n1 < n3 = }")
    print(f"{n1 > n3 = }")
    print(f"{n1 > 5 = }")
    

    Prints:

    n1 < n2 = False
    n1 > n2 = False
    n1 == n2 = True
    n1 < n3 = True
    n1 > n3 = False
    n1 > 5 = True
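Since the original goal was sorting a list of nodes, here is a trimmed-down sketch (only `value` and `name` kept, with a hypothetical `_value_of` helper) that applies `total_ordering` to that case and to comparing a node with a bare number equal to its value. That equal-value case is the one that exposes a `__le__` written with `<` instead of `<=`: `Node(10) <= 10` would then come out `False`.

```python
from functools import total_ordering


@total_ordering
class Node:
    def __init__(self, value, name=None):
        self.value = value
        self.name = name

    @staticmethod
    def _value_of(other):
        # Hypothetical helper: accept another Node or a bare value.
        return other.value if isinstance(other, Node) else other

    def __le__(self, other):
        return self.value <= self._value_of(other)

    def __eq__(self, other):
        return self.value == self._value_of(other)


nodes = [Node(20, "b"), Node(5, "c"), Node(10, "a")]
print([n.name for n in sorted(nodes)])  # ['c', 'a', 'b']
print(max(nodes).name)                  # b
print(Node(10) <= 10, Node(10) > 10)    # True False
```

Sorting only needs `<`; with just `__le__` and `__eq__` written, that `<` is one of the methods `total_ordering` fills in.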