Given a numpy array
arr = np.array([1, 2, 3, 4, 5])
I need to construct a binary mask according to an arbitrary, potentially long list of values, i.e. given
values = np.array([2, 4, 5])
mask should be
mask = np.array([False, True, False, True, True])
So I want to avoid
condition = (arr==2) or (arr==4) or (arr==5)
mask = np.where(condition, arr)
to get something like
mask = np.in(values, arr)
Or, if that is not possible, how do I construct a condition from an arbitrary list of values to feed into np.where?
In [64]: arr = np.array([1, 2, 3, 4, 5])
...: values = np.array([2, 4, 5])
While isin is easy to use, it isn't the only option:
In [66]: np.isin(arr, values)
Out[66]: array([False, True, False, True, True])
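Since the question mentions feeding the result into np.where, here is a minimal sketch of what could be done with the mask once it exists. The names selected, positions and inverted are illustrative and not from the question:

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
values = np.array([2, 4, 5])

mask = np.isin(arr, values)                   # element-wise membership test
selected = arr[mask]                          # boolean indexing keeps the matching elements
positions = np.where(mask)[0]                 # indices at which the mask is True
inverted = np.isin(arr, values, invert=True)  # elements NOT in values, same as ~mask
```

Note that np.where with a single argument returns a tuple of index arrays, hence the [0].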
We could compare the whole arrays:
In [67]: values[:,None]==arr
Out[67]:
array([[False,  True, False, False, False],
       [False, False, False,  True, False],
       [False, False, False, False,  True]])
In [68]: (values[:,None]==arr).any(axis=0)
Out[68]: array([False, True, False, True, True])
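To spell out why the broadcast comparison works, the sketch below names the intermediate result (table is an illustrative name) and notes the shapes involved. One caveat worth keeping in mind: the intermediate array has len(values) * len(arr) elements, so this approach can use a lot of memory when both inputs are long.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
values = np.array([2, 4, 5])

# values[:, None] has shape (3, 1); arr has shape (5,).
# Broadcasting compares every value with every element, giving shape (3, 5).
table = values[:, None] == arr

# Collapse the "values" axis: True wherever at least one value matched.
mask = table.any(axis=0)
```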
Your use of or does not work, because Python's or needs a single truth value from each array:
In [69]: (arr==values[0]) or (arr==values[1]) or (arr==values[2])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[69], line 1
----> 1 (arr==values[0]) or (arr==values[1]) or (arr==values[2])
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
We have to use the element-wise | operator instead (for boolean arrays it behaves like np.logical_or):
In [70]: (arr==values[0]) | (arr==values[1]) | (arr==values[2])
Out[70]: array([False, True, False, True, True])
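Writing out one comparison per value does not scale to an arbitrary list. One possible way to build the same condition programmatically is np.logical_or.reduce. The helper name mask_from_values below is hypothetical, and the empty-list handling is an assumption about what you would want in that case:

```python
import numpy as np

def mask_from_values(arr, values):
    """Hypothetical helper: OR together one equality test per value."""
    comparisons = [arr == v for v in values]
    if not comparisons:
        # No values to match: nothing is selected.
        return np.zeros(arr.shape, dtype=bool)
    # Reduces along the first axis, i.e. comparisons[0] | comparisons[1] | ...
    return np.logical_or.reduce(comparisons)

arr = np.array([1, 2, 3, 4, 5])
mask = mask_from_values(arr, [2, 4, 5])
```

The result can be used as the condition in the three-argument form np.where(mask, x, y), or with boolean indexing as arr[mask].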
Depending on the relative sizes of the two arrays, isin may actually do this kind of logical_or internally.
Let's check the times (actual times will depend on the sizes):
In [71]: timeit (arr==values[0]) | (arr==values[1]) | (arr==values[2])
13.8 μs ± 75.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [72]: timeit np.isin(arr, values)
97 μs ± 242 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [73]: timeit (values[:,None]==arr).any(axis=0)
17 μs ± 101 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
isin is convenient, but it is the slowest in this example.
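Because the ranking may change with the input sizes, it is worth repeating the comparison on data shaped like your own. The following is only a sketch of such a check outside IPython: the array sizes, the value range and the repeat count are arbitrary assumptions, and no particular outcome is implied.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
arr = rng.integers(0, 1000, size=100_000)  # assumed size, adjust to your data
values = rng.integers(0, 1000, size=50)    # assumed size, adjust to your data

candidates = {
    "isin": lambda: np.isin(arr, values),
    "broadcast": lambda: (values[:, None] == arr).any(axis=0),
    "logical_or.reduce": lambda: np.logical_or.reduce([arr == v for v in values]),
}

repeats = 5
for name, fn in candidates.items():
    seconds = timeit.timeit(fn, number=repeats) / repeats
    print(f"{name}: {seconds * 1e3:.3f} ms per call")
```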