
Python: How to Type Hint a Class Argument in a Static Method?


Let's say I have a class Circle with the following definition:

import math

class Circle:
    def __init__(self, r, _id):
        self.r = r
        self.id = _id

    def area(self):
        return math.pi * (self.r ** 2)

I want to write a function that compares two circles and returns the id of the smaller one:

def compare_circles(circle_1: Circle, circle_2: Circle) -> str:
    if circle_1.r < circle_2.r:
        return circle_1.id
    else:
        return circle_2.id

I would like to put this function on the class as a static method. (Is this a bad idea?)

class Circle:
    
    def __init__(self, r, _id):
        self.r = r
        self.id =  _id

    def area(self):
        return math.pi * (self.r ** 2)

    @staticmethod
    def compare_circles(circle_1: Circle, circle_2: Circle) -> str:
        if circle_1.r < circle_2.r:
            return circle_1.id
        else:
            return circle_2.id

However, this does not work: the name Circle is not yet defined while the class body is being executed, so it cannot be used in the type hints. I have to remove the type hints from the method. Using a classmethod does not help either, because cls cannot be referenced in the annotations of the other parameters.
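
To make the failure concrete, here is a minimal sketch that reproduces it. It assumes an interpreter that evaluates annotations eagerly at definition time (Python versions before 3.14 without `from __future__ import annotations`); interpreters that defer annotation evaluation may not raise at all. The `error_message` variable is only there for illustration.

```python
# Sketch: the annotations on compare_circles are evaluated while the class
# body is still executing, before the name Circle has been bound.
error_message = None

try:
    class Circle:
        def __init__(self, r, _id):
            self.r = r
            self.id = _id

        @staticmethod
        def compare_circles(circle_1: Circle, circle_2: Circle) -> str:
            return circle_1.id if circle_1.r < circle_2.r else circle_2.id
except NameError as exc:
    # Reached on interpreters that evaluate annotations eagerly.
    error_message = str(exc)

print(error_message)
```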


Solution

  • One possible workaround is to define the function after the class and then attach it as a static method (monkey-patching):

    import math

    class Circle:
        def __init__(self, r, _id):
            self.r = r
            self.id = _id
    
        def area(self):
            return math.pi * (self.r ** 2)
    
    
    def compare_circles(circle_1: Circle, circle_2: Circle) -> str:
        if circle_1.r < circle_2.r:
            return circle_1.id
        else:
            return circle_2.id
    
    Circle.compare_circles = staticmethod(compare_circles)
    del compare_circles
    

    The usual way is to write the type name as a string (a forward reference), which type checkers resolve once the class is fully defined:

    import math

    class Circle:
        def __init__(self, r, _id):
            self.r = r
            self.id = _id
    
        def area(self):
            return math.pi * (self.r ** 2)
    
        @staticmethod
        def compare_circles(circle_1: 'Circle', circle_2: 'Circle') -> str:
            if circle_1.r < circle_2.r:
                return circle_1.id
            else:
                return circle_2.id
    

    As an aside, you might also consider turning compare_circles into a regular instance method:

    def compare(self, other: 'Circle') -> str:
        if self.r < other.r:
            return self.id
        else:
            return other.id
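
Putting the string-annotation approach and the instance-method variant together, a minimal runnable sketch might look like the following. The `float` and `str` hints on `r` and `_id`, and the example ids "small" and "large", are assumptions made for illustration; the original code does not state those types.

```python
import math


class Circle:
    # The float/str hints here are assumed, not taken from the original.
    def __init__(self, r: float, _id: str) -> None:
        self.r = r
        self.id = _id

    def area(self) -> float:
        return math.pi * (self.r ** 2)

    @staticmethod
    def compare_circles(circle_1: "Circle", circle_2: "Circle") -> str:
        # "Circle" is a forward reference: the string is stored as-is and
        # is not looked up while the class body is still executing.
        return circle_1.id if circle_1.r < circle_2.r else circle_2.id

    def compare(self, other: "Circle") -> str:
        # Instance-method variant of the same comparison.
        return self.id if self.r < other.r else other.id


small = Circle(1.0, "small")
large = Circle(2.0, "large")

smaller_id_static = Circle.compare_circles(small, large)  # "small"
smaller_id_method = large.compare(small)                  # "small"

# The annotation is kept as the plain string "Circle" at runtime.
stored_annotation = Circle.compare_circles.__annotations__["circle_1"]
```

Both variants return the id of the second argument when the radii are equal, which matches the behavior of the original function.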