python arrays numpy astropy

How to pass a NumPy array through a function when the function contains conditional if statements?


I have the following code:

import numpy as np
from astropy.cosmology import FlatLambdaCDM
import matplotlib.pyplot as plt

cosmopar = FlatLambdaCDM(H0 = 67.8,Om0 = 0.3)

td_min = 0.1
d = -1

def t_L(z_arb):
    return cosmopar.lookback_time(z_arb).value

def t_d(z_f,z_m):
    return t_L(z_f)-t_L(z_m)

def P_t(z_f,z_m):
    if (td_min<t_d):
        return t_d**d
    else:
        return 0

Now if I define a NumPy array zf_trial1 = np.linspace(0,30,100) and try to pass it through the function using the command P_t(zf_trial1,3), the function raises the following error:

"The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"

Now I understand why this error is popping up: when the if statement compares against td_min, an array with several elements has some elements that satisfy the condition and some that fail it, so there is no single True/False answer. However, I am not sure how to fix this. All I want to do is pass each element of the NumPy array zf_trial1 through P_t(z_f,z_m).
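The error can be reproduced without astropy: comparing an array with a scalar gives a boolean array, and `if` needs one `True`/`False`, which NumPy refuses to guess. A minimal sketch, where the array `td` is made-up stand-in data for the output of `t_d`:

```python
import numpy as np

td_min = 0.1
td = np.array([-11.5, 0.05, 2.0])  # made-up stand-in values for t_d(z_f, z_m)

mask = td_min < td  # element-wise comparison: array([False, False,  True])

try:
    if mask:  # `if` calls bool(mask), which NumPy refuses for more than one element
        pass
    message = "no error"
except ValueError as err:
    message = str(err)

print(message)
# .any() / .all() collapse the whole array to ONE boolean; that is not a
# per-element decision, so they don't give the element-wise behaviour wanted here
print(mask.any(), mask.all())
```

This is why the suggestion in the error message (`a.any()` or `a.all()`) does not help in this case: the condition has to be applied to each element separately.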

I tried the np.vectorize() function, but the results look wrong: when I plot them, the graph differs from the one I get when I manually input values into the P_t function and plot those. What I tried is as follows:

Pt_vector = np.vectorize(P_t)
Pt_res = Pt_vector(zf_trial1,3)

plt.scatter(zf_trial1,Pt_res)
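One plausible cause of the mismatch (an assumption, not confirmed in the question): `np.vectorize` infers the output dtype from the first call unless `otypes` is given. If the first element falls into the `else` branch, `P_t` returns the integer `0`, the whole result becomes an integer array, and later float results are silently truncated. With `zf_trial1` starting at `0` and `z_m = 3`, the first `t_d` would be negative, so this would happen once `P_t` actually calls `t_d(z_f, z_m)`. A sketch using a hypothetical scalar function `p_scalar` that takes `t_d` directly, with made-up input values:

```python
import numpy as np

td_min = 0.1
d = -1

def p_scalar(td):
    """Hypothetical scalar stand-in for P_t that takes t_d directly."""
    if td_min < td:
        return td**d
    else:
        return 0  # an int, unlike the float returned by the branch above

# made-up values; the first one takes the int branch
td_values = np.array([-11.5, 0.05, 0.4, 2.0])

# output dtype guessed from the first call (int), so 2.5 -> 2 and 0.5 -> 0
broken = np.vectorize(p_scalar)(td_values)

# forcing a float output keeps the fractional results
fixed = np.vectorize(p_scalar, otypes=[float])(td_values)

print(broken.dtype, broken)
print(fixed.dtype, fixed)
```

Returning `0.0` instead of `0` in the `else` branch would avoid the truncation as well; `otypes` just makes the intended output type explicit.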

Solution

  • I'm a bit confused, since P_t takes z_f and z_m as arguments but never uses them: the condition compares td_min against the function t_d itself rather than the value t_d(z_f, z_m). So I don't really see how it is meant to be used, and I'll try to answer the question in a general way.

    You can use np.where inside your function to apply the condition element-wise, like so:

    import numpy as np

    def ex(arr, max_value):
        """
        arr: np.array
        max_value: int

        All values in `arr` below `max_value`
        are raised to the power `d` (a global).
        Values at or above `max_value` are set to `0`.
        """
        return np.where(arr < max_value, arr**d, 0)

    d = 2
    ex(np.arange(0, 10), 5)  # array([ 0,  1,  4,  9, 16,  0,  0,  0,  0,  0])

    Alternatively, you can use a list comprehension, calling P_t on one element at a time (here with z_m = 3, as in the question):

    z_m = 3
    a = np.array([P_t(v, z_m) for v in zf_trial1])
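A caveat when the `np.where` approach is used with the question's `d = -1`: `np.where` evaluates `arr**d` for every element before choosing, so an integer array raises `ValueError: Integers to negative integer powers are not allowed`, and a float array containing zeros triggers a divide-by-zero `RuntimeWarning`. A sketch that computes the power only where the condition holds, using a hypothetical helper `p_from_td` that takes the already-computed `t_d` array (the input values below are made up):

```python
import numpy as np

td_min = 0.1
d = -1

def p_from_td(td, td_min=td_min, d=d):
    """Hypothetical helper: td**d where td > td_min, else 0.0, element-wise."""
    td = np.asarray(td, dtype=float)  # float dtype so negative powers are allowed
    out = np.zeros_like(td)           # 0.0 for every element failing the condition
    mask = td > td_min                # boolean array, one entry per element
    out[mask] = td[mask] ** d         # power computed only where the condition holds
    return out

print(p_from_td([-11.5, 0.0, 0.05, 0.5, 2]))
```

Applied to the question, this would presumably be called as `p_from_td(t_d(zf_trial1, 3))`, assuming `t_d` is given the whole array at once (astropy's `lookback_time` accepts array-like redshifts).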