
np.where with an "in"-type condition


Given a numpy array

arr = np.array([1, 2, 3, 4, 5])

I need to construct a boolean mask marking which elements of arr appear in an (arbitrary, potentially long) list of values. For example, given

values = np.array([2, 4, 5])

the mask should be

mask = np.array([False, True, False, True, True])

I want to avoid writing

condition = (arr==2) or (arr==4) or (arr==5)
mask = np.where(condition, arr)

and instead write something like this (hypothetical) call:

mask = np.in(values, arr)

Or, if that is not possible, how can I construct a condition from an arbitrary list of values to feed into np.where?


Solution

    In [64]: arr = np.array([1, 2, 3, 4, 5])
        ...: values = np.array([2, 4, 5])
    

    np.isin is the most direct tool, but it isn't the only option:

    In [66]: np.isin(arr, values)
    Out[66]: array([False,  True, False,  True,  True])
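Since the question ultimately asks for a condition to feed into np.where, here is a minimal sketch showing that the mask from np.isin can be passed straight in. The fill value 0 is an arbitrary choice for illustration; invert=True is an existing np.isin parameter that returns the complement mask.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
values = np.array([2, 4, 5])

# Boolean mask: True wherever an element of arr appears in values
mask = np.isin(arr, values)

# One-argument np.where: indices where the mask is True (returned as a tuple)
(idx,) = np.where(mask)

# Three-argument np.where: keep matching elements, replace the rest with 0
kept = np.where(mask, arr, 0)

# invert=True gives the complement directly, equivalent to ~mask
not_in = np.isin(arr, values, invert=True)
```

The one-argument form of np.where is equivalent to np.nonzero, so if only the mask itself is needed, np.where is not required at all.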
    

    We could also compare the whole arrays at once, using broadcasting:

    In [67]: values[:,None]==arr
    Out[67]: 
    array([[False,  True, False, False, False],
           [False, False, False,  True, False],
           [False, False, False, False,  True]])
    
    In [68]: (values[:,None]==arr).any(axis=0)
    Out[68]: array([False,  True, False,  True,  True])
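A short sketch of what the broadcasting comparison builds. values[:, None] is a column of shape (3, 1), so comparing it with arr of shape (5,) produces a (3, 5) table with one row per value. That intermediate holds len(values) * len(arr) booleans, which is worth keeping in mind for the "potentially long" lists in the question. np.logical_or.reduce is shown only as an equivalent spelling of .any(axis=0).

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
values = np.array([2, 4, 5])

# values[:, None] has shape (3, 1); arr has shape (5,).
# Broadcasting stretches both to (3, 5): row i compares values[i] with every element of arr.
table = values[:, None] == arr
assert table.shape == (values.size, arr.size)

# Collapse over the values axis: an element matches if any row matched it
mask = table.any(axis=0)

# Equivalent spelling with np.newaxis and an OR-reduction along axis 0
mask2 = np.logical_or.reduce(values[:, np.newaxis] == arr, axis=0)
```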
    

    Your use of Python's or does not work on arrays:

    In [69]: (arr==values[0]) or (arr==values[1]) or (arr==values[2])
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In[69], line 1
    ----> 1 (arr==values[0]) or (arr==values[1]) or (arr==values[2])
    
    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
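The reason is that Python's or (and and, not) must reduce its left operand to a single True/False by calling bool() on it, and NumPy refuses to guess whether a multi-element array should mean any() or all(). A minimal sketch reproducing this outside the chained expression. Note also that | binds more tightly than ==, so the parentheses around each comparison in the element-wise version are required.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])

# `x or y` first evaluates bool(x); for a multi-element array that is ambiguous,
# so NumPy raises ValueError instead of picking any() or all().
try:
    bool(arr == 2)
    raised = False
except ValueError:
    raised = True

# The element-wise operators (|, &, ~) work per element and never call bool()
# on the whole array. The parentheses are needed because | binds tighter than ==.
elementwise = (arr == 2) | (arr == 4)
```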
    

    We have to use the element-wise logical OR instead, i.e. the | operator (np.logical_or):

    In [70]: (arr==values[0]) | (arr==values[1]) | (arr==values[2])
    Out[70]: array([False,  True, False,  True,  True])
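To answer the second part of the question (building the condition from an arbitrary list of values rather than writing out each comparison), the chained | can be folded over the list. This is a sketch using functools.reduce with np.logical_or; starting from an all-False array is a choice made here so that an empty list of values yields an all-False mask instead of an error.

```python
import functools

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
values = [2, 4, 5]  # any iterable of values, possibly long

# Fold | over one comparison per value: False | (arr == v0) | (arr == v1) | ...
# The all-False initial array also covers the case of an empty `values`.
mask = functools.reduce(
    np.logical_or,
    (arr == v for v in values),
    np.zeros(arr.shape, dtype=bool),
)

# The resulting condition can be fed to np.where like any other boolean array
result = np.where(mask, arr, -1)
```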
    

    Depending on the relative sizes of the two arrays, isin may internally do this same kind of logical_or, one comparison per value.
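A sketch of that per-value OR loop, written as a hypothetical helper isin_small (not a NumPy function). It accumulates into one preallocated mask with the in-place |=, which avoids allocating a new array at each step. Which strategy np.isin actually picks is an internal detail that varies between NumPy versions and inputs, so treat this as an illustration of the idea rather than NumPy's code.

```python
import numpy as np


def isin_small(arr, values):
    """Hypothetical helper: membership mask built with one comparison per value.

    Suited to a short `values` list; cost grows with len(values) * arr.size.
    """
    arr = np.asarray(arr)
    mask = np.zeros(arr.shape, dtype=bool)
    for v in np.unique(values):  # unique() skips redundant comparisons
        mask |= arr == v         # in-place OR into the preallocated mask
    return mask
```

Usage: isin_small(arr, values) returns the same mask as np.isin(arr, values) for the arrays in this question.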

    Let's check the times (actual timings will depend on the array sizes and the machine):

    In [71]: timeit (arr==values[0]) | (arr==values[1]) | (arr==values[2])
    13.8 μs ± 75.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    
    In [72]: timeit np.isin(arr, values)
    97 μs ± 242 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    
    In [73]: timeit (values[:,None]==arr).any(axis=0)
    17 μs ± 101 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
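The timeit calls above are IPython magics. Outside IPython, a comparable measurement can be sketched with the standard-library timeit module; the repeat counts below are arbitrary. It first checks that all three approaches agree, and it makes no claim about which wins: changing the sizes of arr and values (as with the long lists in the question) is the way to see how the ranking shifts on a given machine.

```python
import timeit

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
values = np.array([2, 4, 5])

candidates = {
    "chained |": lambda: (arr == values[0]) | (arr == values[1]) | (arr == values[2]),
    "np.isin": lambda: np.isin(arr, values),
    "broadcast": lambda: (values[:, None] == arr).any(axis=0),
}

# All approaches must agree before their timings are worth comparing
reference = np.isin(arr, values)
for name, fn in candidates.items():
    assert np.array_equal(fn(), reference), name

# Best-of-5 per-call time in microseconds; results depend on machine and sizes
n = 1_000
for name, fn in candidates.items():
    best = min(timeit.repeat(fn, number=n, repeat=5)) / n
    print(f"{name:>10}: {best * 1e6:.1f} us per call")
```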
    

    isin is convenient, but it is the slowest of the three in this small example.