I am trying to extract, from an array, all values within a certain slice (something like a range, but with optional start, stop and step). In doing so, I want to benefit from the heavy optimizations that range objects employ for range.__contains__(), which mean they never have to instantiate the full range of values (compare Why is "1000000000000000 in range(1000000000000001)" so fast in Python 3?).
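As a small illustration of the optimization referred to above, the snippet below checks membership in very large range objects. For integer arguments the test is done arithmetically from start, stop and step, so nothing is materialized. The variable names are only illustrative.

```python
# Membership in a range is decided from start/stop/step alone,
# so even an enormous range answers immediately.
big = range(10**15 + 1)
evens = range(0, 10**15, 2)

in_big = 10**15 in big    # last element of the range
in_evens = 7 in evens     # odd number, not on the step grid

print(in_big)    # True
print(in_evens)  # False
```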
The following code works, but it's horribly inefficient because i is converted to a full-fledged array, increasing memory use and runtime.
import numpy as np
arr = np.array([0, 20, 29999999, 10, 30, 40, 50])
M = np.max(arr)
# slice based on values
s = slice(20, None)
i = range(*s.indices(M + 1))
print(arr[np.isin(arr, i)]) # works, but inefficient!
Output:
[ 20 29999999 30 40 50]
Is there a numpy function to improve that directly? Should I use np.vectorize/np.where instead with a callback using the slice (which feels like it could be slow as well, going back from C++ to Python for every single element [or would it not be doing that?])? Should I subtract start from my values, divide by step, and then see if values are >= 0 && < (stop - start) / step? Or am I missing a much better way?
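The subtract-and-divide idea can be sketched as a vectorized counterpart of range.__contains__(). This is only a sketch under assumptions: in_range_mask is a hypothetical helper name, the array is assumed to hold integers, and the range is assumed small enough for len() to work. It relies on np.divmod following Python's floor-division semantics, which is what lets it handle negative steps as well.

```python
import numpy as np

def in_range_mask(arr, r):
    """Boolean mask of the elements of arr contained in the range object r.

    Hypothetical helper; assumes an integer array.
    """
    arr = np.asarray(arr)
    if len(r) == 0:
        return np.zeros(arr.shape, dtype=bool)
    # q is the would-be position inside r, rem the offset from the step grid.
    q, rem = np.divmod(arr - r.start, r.step)
    return (rem == 0) & (q >= 0) & (q < len(r))

arr = np.array([0, 20, 29999999, 10, 30, 40, 50])
print(arr[in_range_mask(arr, range(20, 30000000))])  # values >= 20
print(arr[in_range_mask(arr, range(40, 5, -10))])    # 40, 30, 20, 10 in array order
```

Only a few temporary arrays the size of arr are created; the range itself is never expanded.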
Based on this comment, I think you just need to check >= and < and return the nonzero indices from that result.
import numpy as np
def get_indices_within_range(arr, start, stop, step=None):
    if start is None and stop is None:
        raise ValueError("At least one of start and stop must not be None.")
    # Start with every element selected, then narrow the mask down.
    mask = np.ones_like(arr, dtype=bool)
    if start is not None:
        mask &= arr >= start
    if stop is not None:
        mask &= arr < stop
    if step is not None and step != 1:
        # Keep only values on the step grid, anchored at start (or at 0).
        offset = start if start is not None else 0
        mask &= (arr - offset) % step == 0
    # For 1-D input, return a flat index array rather than a 1-tuple.
    if arr.ndim > 1:
        return np.nonzero(mask)
    return np.nonzero(mask)[0]
rng = np.random.default_rng(2)
arr = rng.integers(0, 100, size=(20,))
print(arr)
indices = get_indices_within_range(arr, 20, 70)
print(indices)
indices = get_indices_within_range(arr, 20, None)
print(indices)
indices = get_indices_within_range(arr, None, 70)
print(indices)
indices = get_indices_within_range(arr, 20, 70, 10)
print(indices)
indices = get_indices_within_range(arr, None, 70, 10)
print(indices)
Results:
[83 26 10 29 41 81 45 9 33 60 81 72 99 18 88 5 55 27 20 65]
[ 1 3 4 6 8 9 16 17 18 19]
[ 0 1 3 4 5 6 8 9 10 11 12 14 16 17 18 19]
[ 1 2 3 4 6 7 8 9 13 15 16 17 18 19]
[ 9 18]
[ 2 9 18]
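To tie this back to the data in the question, the comparison approach can be driven directly by a slice object. The following is a minimal sketch, not a definitive implementation: select_by_value_slice is a hypothetical helper name, and it assumes a non-negative integer array and a positive step, since slice.indices() interprets negative bounds relative to the length it is given.

```python
import numpy as np

def select_by_value_slice(arr, s):
    """Return the elements of arr whose values fall inside the slice s.

    Hypothetical helper; assumes non-negative integers and a positive step.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return arr
    # slice.indices() fills in None for start/stop/step, as in the question.
    start, stop, step = s.indices(int(arr.max()) + 1)
    if step <= 0:
        raise ValueError("this sketch only handles positive steps")
    mask = (arr >= start) & (arr < stop)
    if step != 1:
        mask &= (arr - start) % step == 0
    return arr[mask]

arr = np.array([0, 20, 29999999, 10, 30, 40, 50])
print(select_by_value_slice(arr, slice(20, None)))      # values >= 20
print(select_by_value_slice(arr, slice(None, 45, 20)))  # 0, 20 and 40
```

Returning arr[mask] gives the values, as the question asks; np.nonzero(mask) would give the indices instead.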