I'm interested in finding the lengths of sequences of 1's along a single axis in a multi-dimensional array.
For a 1D array I have worked my way to a solution using the answers at this older question. E.g. [0,1,0,0,1,1,1,0,1,1] --> [nan,1,nan,nan,3,nan,nan,nan,2,nan]
For a 3D array I can of course create a loop, but I'd rather not. (Background: climate science; looping over all latitude/longitude grid cells is going to make this very slow.)
I'm trying to find a solution in line with the 1D solution. Help would be much appreciated, ideally in line with the 1D code, but completely different solutions are welcome too, of course.
For reference, this is my working 1D solution:
import xarray as xr
import numpy as np
def run_lengths(da):
    """Return the length of each run of 1's in a 1-D series, placed at the run start.

    Parameters
    ----------
    da : xarray.DataArray
        One-dimensional array of 0/1 values with a ``time`` coordinate.

    Returns
    -------
    xarray.DataArray
        Array on the full ``time`` axis of ``da``. The first time step of
        every run of 1's holds that run's length; all other steps are NaN
        (they are introduced by the outer alignment at the end).
    """
    values = da.values
    n_steps = len(da)

    # A run ends wherever the next sample differs; the last sample always
    # closes the final run.
    changes = values[1:] != values[:-1]
    run_ends = np.append(np.where(changes), n_steps - 1)

    # Run sizes are the gaps between consecutive run ends (with a virtual
    # end at -1 before the first sample); each run starts right after the
    # previous run's end.
    run_sizes = np.diff(np.append(-1, run_ends))
    run_starts = run_ends - run_sizes + 1

    # Keep only the runs whose value is 1.
    ones_runs = np.where(da[run_ends] == 1)[0]
    lengths = run_sizes[ones_runs]
    start_times = da.time[run_starts[ones_runs]]

    sparse_runs = xr.DataArray(lengths, coords={'time': start_times})
    # The outer join restores the full time axis of ``da``.
    aligned = xr.align(da, sparse_runs, join='outer')
    return aligned[1]
# Demo data: 2 lat x 3 lon grid cells, each with a 6-step 0/1 time series.
da = xr.DataArray(
    np.array(
        [
            [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
            [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        ]
    ),
    coords={'lat': [0, 1], 'lon': [0, 1, 2], 'time': [0, 1, 2, 3, 4, 5]},
)
# Run the 1-D solution on a single grid cell (first lat, second lon).
da_runs = run_lengths(da[0, 1])
print(da_runs)
<xarray.DataArray (time: 6)>
array([ 1., nan, 2., nan, nan, 1.])
Coordinates:
* time (time) int64 0 1 2 3 4 5
And this is the attempt in 3D. I'm stuck on how to shift the valid entries in `i` to the front / remove the NaNs from `i`. (And maybe beyond that as well?)
def run_lengths_3D(da):
    """Unfinished 3-D counterpart of the 1-D run-length routine.

    Only the first steps are implemented: it locates value changes along the
    last axis and builds an index array ``i`` that is NaN wherever no change
    occurs. Nothing is returned yet.

    Parameters
    ----------
    da : xarray.DataArray
        Array indexed as (lat, lon, time) with ``lat``, ``lon`` and ``time``
        coordinates.
    """
    n = len(da.time)  # not used yet by the steps below

    values = da.values
    # True where a sample differs from its successor along the last axis.
    changed = values[:, :, 1:] != values[:, :, :-1]
    step_times = da.time[0:-1]
    y = xr.DataArray(
        changed,
        coords={'lat': da.lat, 'lon': da.lon, 'time': step_times},
    )

    # Per-time-step position, broadcast against the change mask; entries
    # without a change stay NaN because of ``where``.
    positions = xr.DataArray(
        np.arange(0, len(step_times)),
        coords={'time': y.time},
    )
    i = y.where(y) * positions - 1
For this task you can try to use numba, e.g.:
import numba
import numpy as np
@numba.njit
def calculate(arr):
    """Mark the start of every run of non-zero values along the last axis.

    Parameters
    ----------
    arr : numpy.ndarray
        3-D array whose shape is unpacked as (lat, lon, time).

    Returns
    -------
    numpy.ndarray
        ``uint16`` array with the same shape as ``arr``. The first position
        of each run of non-zero values holds the run length; every other
        position is 0. Lengths are stored as ``uint16``, so runs longer than
        65535 cannot be represented.
    """
    out = np.empty_like(arr, dtype="uint16")
    n_lat, n_lon, n_time = arr.shape
    for row in range(n_lat):
        for col in range(n_lon):
            run = 0  # length of the run currently being scanned
            for step in range(n_time):
                # ``out`` is allocated uninitialised, so clear each cell as
                # it is visited.
                out[row, col, step] = 0
                if arr[row, col, step] != 0:
                    run += 1
                elif run > 0:
                    # A run just ended: store its length at its first step.
                    out[row, col, step - run] = run
                    run = 0
            if run > 0:
                # The series ended inside a run; it starts ``run`` steps
                # before the end of the axis.
                out[row, col, n_time - run] = run
    return out
# Demo input of shape (2, 3, 6): the same three 0/1 time series for both
# entries along the first axis.
arr = np.array(
    [
        [
            [0, 1, 1, 0, 0, 0],
            [1, 0, 1, 1, 0, 1],
            [1, 1, 1, 1, 0, 1],
        ],
        [
            [0, 1, 1, 0, 0, 0],
            [1, 0, 1, 1, 0, 1],
            [1, 1, 1, 1, 0, 1],
        ],
    ]
)
print(calculate(arr))
Prints:
[[[0 2 0 0 0 0]
[1 0 2 0 0 1]
[4 0 0 0 0 1]]
[[0 2 0 0 0 0]
[1 0 2 0 0 1]
[4 0 0 0 0 1]]]
Benchmark using timeit, plus a parallel version:
from timeit import timeit
import numba
import numpy as np
@numba.njit
def calculate(arr):
    """Mark the start of every run of non-zero values along the last axis.

    Parameters
    ----------
    arr : numpy.ndarray
        3-D array whose shape is unpacked as (lat, lon, time).

    Returns
    -------
    numpy.ndarray
        ``uint16`` array with the same shape as ``arr``. The first position
        of each run of non-zero values holds the run length; every other
        position is 0. Lengths are stored as ``uint16``, so runs longer than
        65535 cannot be represented.
    """
    out = np.empty_like(arr, dtype="uint16")
    n_lat, n_lon, n_time = arr.shape
    for row in range(n_lat):
        for col in range(n_lon):
            run = 0  # length of the run currently being scanned
            for step in range(n_time):
                # ``out`` is allocated uninitialised, so clear each cell as
                # it is visited.
                out[row, col, step] = 0
                if arr[row, col, step] != 0:
                    run += 1
                elif run > 0:
                    # A run just ended: store its length at its first step.
                    out[row, col, step - run] = run
                    run = 0
            if run > 0:
                # The series ended inside a run; it starts ``run`` steps
                # before the end of the axis.
                out[row, col, n_time - run] = run
    return out
@numba.njit(parallel=True)
def calculate_parallel(arr):
    """Parallel variant of the run-length scan along the last axis.

    Same result as the serial version: the first position of each run of
    non-zero values holds the run length and every other position is 0. The
    loop over the second axis is a ``numba.prange`` loop; each (row, col)
    series is scanned independently and writes only to its own slice of the
    output.

    Parameters
    ----------
    arr : numpy.ndarray
        3-D array whose shape is unpacked as (lat, lon, time).

    Returns
    -------
    numpy.ndarray
        ``uint16`` array with the same shape as ``arr``.
    """
    out = np.empty_like(arr, dtype="uint16")
    n_lat, n_lon, n_time = arr.shape
    for row in range(n_lat):
        for col in numba.prange(n_lon):
            run = 0  # length of the run currently being scanned
            for step in range(n_time):
                # ``out`` is allocated uninitialised, so clear each cell as
                # it is visited.
                out[row, col, step] = 0
                if arr[row, col, step] != 0:
                    run += 1
                elif run > 0:
                    # A run just ended: store its length at its first step.
                    out[row, col, step - run] = run
                    run = 0
            if run > 0:
                # The series ended inside a run; it starts ``run`` steps
                # before the end of the axis.
                out[row, col, n_time - run] = run
    return out
# Small hand-checkable input, using the same dtype (uint8) as the benchmark
# array created below.
arr = np.array(
[
[[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
[[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
],
dtype="uint8",
)
# Warm-up: calling both functions once compiles calculate() and
# calculate_parallel() before anything is timed; the assert also checks that
# the serial and parallel variants produce the same result.
assert np.allclose(calculate(arr), calculate_parallel(arr))
# Fixed seed so the random benchmark input is reproducible.
np.random.seed(42)
# Random 0/1 data of shape (256, 512, 3650); at one byte per element this is
# roughly 478 MB.
arr = np.random.randint(low=0, high=2, size=(256, 512, 3650), dtype="uint8")
# Each variant is timed with a single call (number=1); globals() makes `arr`
# and the two functions visible to the timed statement.
t_serial = timeit("calculate(arr)", number=1, globals=globals())
t_parallel = timeit("calculate_parallel(arr)", number=1, globals=globals())
# timeit reports seconds; scale to microseconds for printing.
print(f"{t_serial * 1_000_000:.2f} usec")
print(f"{t_parallel * 1_000_000:.2f} usec")
Prints on my machine (AMD 5700x):
1575227.47 usec
320453.57 usec