python numpy overloading

How to get a python function to work on an np.array or a float, with conditional logic


I have a function that I'd like to take numpy arrays or floats as input. I want to keep doing an operation until some measure of error is less than a threshold.

A simple example would be the following, which divides a number or array by 2 until it is below a threshold (if a float), or until its maximum is below a threshold (if an array).

def f(x):  # float version
    while x > 1e-5:
        x = x / 2
    return x

def f(x):  # np array version
    while max(x) > 1e-5:
        x = x / 2
    return x

Unfortunately, the built-in max raises a TypeError if x is not iterable, and while x>1e-5 raises a ValueError if x is an array with more than one element, because the truth value of such an array is ambiguous. The only thing I've found that might do this is np.vectorize, but that doesn't seem to be as efficient as I would want. How can I get a single function to handle both cases?
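For reference, both failures can be reproduced directly. This is only a minimal sketch; the helper name `describe_failure` is made up for illustration and is not part of NumPy.

```python
import numpy as np

def describe_failure(func, arg):
    """Call func(arg) and return the name of the exception raised, or None."""
    try:
        func(arg)
    except Exception as exc:
        return type(exc).__name__
    return None

# The built-in max() needs an iterable, so a plain float fails.
float_error = describe_failure(max, 3.0)

# `while x > 1e-5` calls bool() on the comparison result; for an array
# with more than one element that truth value is ambiguous.
array_error = describe_failure(lambda a: bool(a > 1e-5), np.array([1.0, 2.0]))

print(float_error)
print(array_error)
```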


Solution

  • What about checking the type of the input inside the function and adapting it?

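    import numpy as np

    def f(x):    # for float or np.array
        # isinstance also accepts ints and float subclasses such as
        # np.float64, which `type(x) is float` would reject
        if isinstance(x, (int, float)):
            x = x * np.ones(1)   # promote the scalar to a one-element array
        while np.max(x) > 1e-5:
            x = x / 2
        return x   # note: scalar input comes back as a one-element array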
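As an alternative, the type check may not be needed at all, because `np.max` also accepts a plain scalar. The following is a minimal sketch, assuming the goal is to get a float back for float input and an array back for array input. The name `halve_until_small` and the `tol` parameter are illustrative and do not come from the question.

```python
import numpy as np

def halve_until_small(x, tol=1e-5):
    """Halve x until its largest element is at most tol.

    Works for Python floats and NumPy arrays alike, because np.max
    treats a scalar like a zero-dimensional array.
    """
    while np.max(x) > tol:
        x = x / 2
    return x

print(halve_until_small(1.0))                   # float in, float out
print(halve_until_small(np.array([1.0, 4.0])))  # array in, array out
```

Since the whole array is halved on every pass, all elements are divided the same number of times, as in the array version from the question. Note that `np.max` raises a ValueError on an empty array, so that case would need separate handling.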