
How can I make my class more robust to operator/function overloading?


I am writing a class whose objects are initialised with two parameters (a, b). The intention is to assign instances of this class to variables so that I can write an equation symbolically in Python code, while operator overloading performs custom operations on a and b.

import numpy as np


class my_class(object):

    def __init__(self, a, b):
        self.value1 = a
        self.value2 = b

    # Example of an overloaded operator that works just fine
    def __mul__(self, other):
        new_a = self.value1 * other
        new_b = self.value2 * np.absolute(other)
        return my_class(new_a, new_b)


if __name__ == "__main__":
    my_object = my_class(100, 1)

    print(np.exp(my_object))    # This doesn't work!

Running the example code above produces the following output:

TypeError: loop of ufunc does not support argument 0 of type my_class which has no callable exp method

By guesswork, I worked out that the complaint about "no callable exp method" probably meant I needed to define an exp method in my class:

# Method added inside my_class
def exp(self):
    new_val1 = np.exp(self.value1)
    new_val2 = np.absolute(new_val1) * self.value2
    return my_class(new_val1, new_val2)

This works fine. But now I will have to write another method for np.expm1() and every other function I need. Fortunately I only need np.exp() and np.log() to work, but when I also tried math.exp() on my object, I got a TypeError.

So my question is: defining a custom exp method made the NumPy function work, but how am I supposed to make math.exp() work? It must be possible, because math.exp() accepts a NumPy array as long as it has only one element: NumPy somehow turns the 1-element array into a scalar that math.exp() can use. I guess this is technically about function overloading rather than operator overloading, but before I realised that defining exp was the fix for my first problem, I had no idea why a method like __rpow__ wasn't being called.


Solution

  • np.exp(my_object) is effectively evaluated as np.exp(np.array(my_object)).

    np.array(my_object) is an object-dtype array. For object arrays, np.exp calls elem.exp() on each element. That doesn't work for most classes, since they don't implement such a method.
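A minimal sketch of that object-dtype dispatch, using a hypothetical Demo class (not from the question): np.exp finds and calls the element's exp method, np.multiply falls back to the element's __mul__, and a ufunc with no matching method fails. The exact exception for the missing method has varied between NumPy versions, so the sketch catches both TypeError and AttributeError.

```python
import numpy as np


class Demo:
    """Hypothetical class used only to show how object-dtype ufunc loops dispatch."""

    def __init__(self, x):
        self.x = x

    def exp(self):
        # np.exp on an object array looks up and calls this method
        return Demo(float(np.exp(self.x)))

    def __mul__(self, other):
        # np.multiply on an object array falls back to the Python * operator
        return Demo(self.x * other)


d = Demo(1.0)
arr = np.asarray(d)              # the object is wrapped in a 0-d object array
print(arr.dtype, arr.shape)      # object ()
print(type(np.exp(d)).__name__)  # Demo
print(np.multiply(d, 3).x)       # 3.0

try:
    np.sqrt(d)                   # Demo has no sqrt method
except (TypeError, AttributeError) as err:  # type depends on NumPy version
    print(type(err).__name__)
```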

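    The same applies to other ufuncs and to operators applied to object arrays: np.log calls elem.log(), while arithmetic such as np.multiply falls back to the elements' Python operators (__mul__ and so on).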
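If the goal is to support several NumPy functions without adding one method per ufunc, NumPy also lets a class intercept ufunc calls through the __array_ufunc__ protocol. The sketch below is one possible design under that assumption, not the only one: it reuses the exp rule from the question, and the _UFUNC_HANDLERS table and _exp helper are hypothetical names. Any ufunc not in the table returns NotImplemented, which makes NumPy raise a TypeError.

```python
import numpy as np


class my_class:
    """Sketch: one NumPy hook instead of one method per ufunc."""

    def __init__(self, a, b):
        self.value1 = a
        self.value2 = b

    def _exp(self):
        # Same rule as the exp method in the question
        new_val1 = np.exp(self.value1)
        return my_class(new_val1, np.absolute(new_val1) * self.value2)

    # Hypothetical dispatch table (ufunc -> handler); add np.log etc. here
    _UFUNC_HANDLERS = {np.exp: _exp}

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        handler = self._UFUNC_HANDLERS.get(ufunc)
        # Only plain unary calls such as np.exp(obj) are handled in this sketch
        if method != "__call__" or handler is None or len(inputs) != 1 or kwargs:
            return NotImplemented  # NumPy then raises TypeError
        return handler(inputs[0])


r = np.exp(my_class(0.0, 2.0))
print(type(r).__name__, r.value1, r.value2)  # my_class 1.0 2.0
```

One side effect worth checking in the NumPy documentation: once a class defines __array_ufunc__, an expression such as np.ones(2) * obj is routed through np.multiply and this hook, so it may raise TypeError instead of falling back to an __rmul__ method.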

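    math.exp is an unrelated implementation that knows nothing about NumPy or about your class. It converts its argument to a Python float, which works for any object that defines __float__, and it always returns a plain float. NumPy arrays define __float__, but it raises an error unless the array has exactly one element.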
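To illustrate the __float__ point, here is a sketch with a hypothetical Val class (its attributes mirror the question's value1/value2): defining __float__ is enough for math.exp to accept the object, but the result is a plain float, so value2 is lost. A hypothetical class without __float__ is rejected with a TypeError.

```python
import math


class Val:
    """Hypothetical two-value class that supports math.exp via __float__."""

    def __init__(self, a, b):
        self.value1 = a
        self.value2 = b

    def __float__(self):
        # math.exp calls this to obtain a plain float; value2 is discarded
        return float(self.value1)


class NoFloat:
    """Hypothetical class without __float__, for comparison."""


print(math.exp(Val(0.0, 5)))  # 1.0 (a plain float, not a Val)

try:
    math.exp(NoFloat())
except TypeError as err:
    print(type(err).__name__)  # TypeError
```

So math.exp can be made to accept the object, but it cannot be made to return one of your objects.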

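    By contrast, * is dispatched by the interpreter itself, which calls the class's __mul__ without any help from NumPy. np.exp, on the other hand, is not computed as e ** x, so __rpow__ is never consulted.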
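A small tracing sketch (the Tracer class is hypothetical) makes the difference visible: the expression math.e ** t goes through the interpreter, which tries float.__pow__, gets NotImplemented, and then calls Tracer.__rpow__; np.exp(t) goes through NumPy's object loop and calls t.exp() instead.

```python
import math

import numpy as np


class Tracer:
    """Hypothetical class that records which hooks get called."""

    def __init__(self):
        self.calls = []

    def __rpow__(self, base):
        # Reached via the interpreter after float.__pow__ returns NotImplemented
        self.calls.append("__rpow__")
        return self

    def exp(self):
        # Reached via NumPy's object-dtype loop for np.exp
        self.calls.append("exp")
        return self


t = Tracer()
math.e ** t
np.exp(t)
print(t.calls)  # ['__rpow__', 'exp']
```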


    The error message is the same whether an array is passed to math.exp or its __float__() is called directly:

    In [52]: math.exp(np.array([1,2,3]))
    Traceback (most recent call last):
      File "<ipython-input-52-40503a52084a>", line 1, in <module>
        math.exp(np.array([1,2,3]))
    TypeError: only size-1 arrays can be converted to Python scalars
    
    In [53]: np.array([1,2,3]).__float__()
    Traceback (most recent call last):
      File "<ipython-input-53-0bacdf9df4e7>", line 1, in <module>
        np.array([1,2,3]).__float__()
    TypeError: only size-1 arrays can be converted to Python scalars
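The shared message suggests math.exp is simply reporting whatever __float__ raises. A sketch with a hypothetical Vec container that copies NumPy's size-1 rule (the message text is made up to resemble NumPy's) shows the same pattern for a user-defined class:

```python
import math


class Vec:
    """Hypothetical container that mimics ndarray's size-1 rule in __float__."""

    def __init__(self, *items):
        self.items = list(items)

    def __float__(self):
        # float() and math.exp() both end up here
        if len(self.items) != 1:
            raise TypeError("only size-1 Vec objects can be converted to Python scalars")
        return float(self.items[0])


print(math.exp(Vec(0.0)))  # 1.0

messages = []
for call in (lambda: math.exp(Vec(1, 2, 3)), lambda: float(Vec(1, 2, 3))):
    try:
        call()
    except TypeError as err:
        messages.append(str(err))

print(messages[0] == messages[1])  # True: math.exp passes on __float__'s error
```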
    

    Similarly, the error raised when an array is used in a boolean context (e.g. an if statement) is the one generated by __bool__():

    In [55]: np.array([1,2,3]).__bool__()
    Traceback (most recent call last):
      File "<ipython-input-55-04aca1612817>", line 1, in <module>
        np.array([1,2,3]).__bool__()
    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
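The same mechanism is available to your own classes. As a sketch, a hypothetical Ambiguous class that raises from __bool__ makes any if statement on it fail, just as a multi-element array does:

```python
class Ambiguous:
    """Hypothetical class that, like a multi-element ndarray, refuses truth testing."""

    def __bool__(self):
        # `if obj:`, `while obj:`, `not obj` and bool(obj) all call this
        raise ValueError("The truth value of an Ambiguous object is ambiguous")


try:
    if Ambiguous():
        print("never reached")
except ValueError as err:
    print(err)
```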
    

    Similarly, using a sympy Relational (here x is a sympy Symbol) in an if raises the error produced by its __bool__():

    In [110]: (x>0).__bool__()
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-110-d88b76ce6b22> in <module>
    ----> 1 (x>0).__bool__()
    
    /usr/local/lib/python3.8/dist-packages/sympy/core/relational.py in __bool__(self)
        396 
        397     def __bool__(self):
    --> 398         raise TypeError("cannot determine truth value of Relational")
        399 
        400     def _eval_as_set(self):
    
    TypeError: cannot determine truth value of Relational
    

    pandas Series produce a similar error.