I am looking for help speeding up a rolling calculation in pandas that computes, for each window, the average of at most a predefined number of the most recent non-NaN observations. Here is code to generate an example frame, followed by the frame itself:
import pandas as pd
import numpy as np
tmp = pd.DataFrame(
    [
        [11.1]*3 + [12.1]*3 + [13.1]*3 + [14.1]*3 + [15.1]*3 + [16.1]*3 + [17.1]*3 + [18.1]*3,
        ['A', 'B', 'C']*8,
        [np.nan]*6 + [1, 1, 1] + [2, 2, 2] + [3, 3, 3] + [np.nan]*9
    ],
    index=['Date', 'Name', 'Val']
)
tmp = tmp.T.pivot(index='Date', columns='Name', values='Val')
Name    A    B    C
Date
11.1  NaN  NaN  NaN
12.1  NaN  NaN  NaN
13.1    1    1    1
14.1    2    2    2
15.1    3    3    3
16.1  NaN  NaN  NaN
17.1  NaN  NaN  NaN
18.1  NaN  NaN  NaN
I would like to obtain this result:
Name    A    B    C
Date
11.1  NaN  NaN  NaN
12.1  NaN  NaN  NaN
13.1  1.0  1.0  1.0
14.1  1.5  1.5  1.5
15.1  2.5  2.5  2.5
16.1  2.5  2.5  2.5
17.1  3.0  3.0  3.0
18.1  NaN  NaN  NaN
I tried the following code, and it works, but it is far too slow for the data sets I have to handle in practice.
tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)
The calculation above takes about 20 minutes on a 3k x 50k frame. Is there a more elegant and faster way to obtain the same result, perhaps by combining the results of several rolling computations or by using groupby?
Versions: Python 3.9.13, pandas 2.0.3, numpy 1.25.2.
One idea is to use numba (it must be installed) to speed up Rolling.apply by passing the parameter engine='numba':
(tmp.rolling(window=3, min_periods=1)
.apply(lambda x: x[~np.isnan(x)][-2:].mean(), engine='numba', raw=True))
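To make explicit what the lambda computes for each window, here is a minimal sketch using a named function. The name mean_of_last_valid and its parameter n are hypothetical and not part of the original code. With raw=True, each window arrives as a plain NumPy array, so the function can be tried on single windows directly:

```python
import numpy as np

def mean_of_last_valid(window, n=2):
    """Mean of the last `n` non-NaN values in a 1-D window (hypothetical helper)."""
    valid = window[~np.isnan(window)]   # drop NaNs, keep the original order
    return valid[-n:].mean()            # average at most the last n survivors

# Three windows taken from column A of the example frame (window=3).
print(mean_of_last_valid(np.array([np.nan, 1.0, 2.0])))     # 1.5
print(mean_of_last_valid(np.array([1.0, 2.0, 3.0])))        # 2.5
print(mean_of_last_valid(np.array([3.0, np.nan, np.nan])))  # 3.0
```

Because min_periods=1 is set, pandas only evaluates windows that contain at least one non-NaN value, so the empty-slice case does not arise inside Rolling.apply.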
Test performance:
tmp = pd.concat([tmp] * 100000, ignore_index=True)
In [88]: %timeit tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(),engine='numba', raw=True)
901 ms ± 6.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [89]: %timeit tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)
13 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Numpy approach:
You can prepend window_size - 1 rows of NaNs to the data, convert it to a 3D array of sliding windows, shift the non-NaN values to the end of each window, and then take the mean of the last N positions:
#https://stackoverflow.com/a/44559180/2901002
def justify_nd(a, invalid_val, axis, side):
    """
    Justify ndarray for the valid elements (that are not invalid_val).

    Parameters
    ----------
    a : ndarray
        Input array to be justified
    invalid_val : scalar
        invalid value
    axis : int
        Axis along which justification is to be made
    side : str
        Direction of justification. Must be 'front' or 'end'.
        So, with 'front', valid elements are pushed to the front and
        with 'end' valid elements are pushed to the end along specified axis.
    """
    pushax = lambda a: np.moveaxis(a, axis, -1)
    if invalid_val is np.nan:
        mask = ~np.isnan(a)
    else:
        mask = a!=invalid_val
    justified_mask = np.sort(mask,axis=axis)
    if side=='front':
        justified_mask = np.flip(justified_mask,axis=axis)
    out = np.full(a.shape, invalid_val)
    if (axis==-1) or (axis==a.ndim-1):
        out[justified_mask] = a[mask]
    else:
        pushax(out)[pushax(justified_mask)] = pushax(a)[pushax(mask)]
    return out
from numpy.lib.stride_tricks import sliding_window_view as swv
window_size = 3
N = 2
a = tmp.astype(float).to_numpy()
arr = np.vstack([np.full((window_size-1,a.shape[1]), np.nan),a])
out = np.nanmean(justify_nd(swv(arr, window_size, axis=0),
                            invalid_val=np.nan, axis=2, side='end')[:, :, -N:],
                 axis=2)
print (out)
[[nan nan nan]
[nan nan nan]
[1. 1. 1. ]
[1.5 1.5 1.5]
[2.5 2.5 2.5]
[2.5 2.5 2.5]
[3. 3. 3. ]
[nan nan nan]]
df = pd.DataFrame(out, index=tmp.index, columns=tmp.columns)
print (df)
Name    A    B    C
Date
11.1  NaN  NaN  NaN
12.1  NaN  NaN  NaN
13.1  1.0  1.0  1.0
14.1  1.5  1.5  1.5
15.1  2.5  2.5  2.5
16.1  2.5  2.5  2.5
17.1  3.0  3.0  3.0
18.1  NaN  NaN  NaN
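The individual steps of the NumPy approach may be easier to follow on a single column. The sketch below walks through padding, windowing, justification and averaging for column A only. It inlines a simplified last-axis version of the justification rather than calling justify_nd, and the variable names (col, padded, justified, n_last) are chosen for illustration:

```python
import warnings
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Column A of the example frame.
col = np.array([np.nan, np.nan, 1.0, 2.0, 3.0, np.nan, np.nan, np.nan])
window_size, n_last = 3, 2

# Step 1: prepend window_size - 1 NaNs so every original row gets a full window.
padded = np.concatenate([np.full(window_size - 1, np.nan), col])

# Step 2: one trailing window per original row, shape (8, 3).
windows = sliding_window_view(padded, window_size)

# Step 3: push the non-NaN values of each window to the right-hand side.
mask = ~np.isnan(windows)
justified_mask = np.sort(mask, axis=-1)   # False sorts before True
justified = np.full(windows.shape, np.nan)
justified[justified_mask] = windows[mask]

# Step 4: average the last n_last slots, ignoring NaNs that remain.
with warnings.catch_warnings():
    # All-NaN windows trigger a "Mean of empty slice" warning and yield NaN.
    warnings.simplefilter("ignore", category=RuntimeWarning)
    result = np.nanmean(justified[:, -n_last:], axis=-1)

print(result)
```

The printed values should match column A of the frame above. The assignment in step 3 works because both masks hold the same number of True entries in every row, so values are written back row by row in their original order.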
Performance:
tmp = pd.concat([tmp] * 100000, ignore_index=True)
In [99]: %%timeit
...: a = tmp.astype(float).to_numpy()
...: arr = np.vstack([np.full((window_size-1,a.shape[1]), np.nan),a])
...:
...: out = np.nanmean(justify_nd(swv(arr, window_size, axis=0),
...:                             invalid_val=np.nan,
...:                             axis=2, side='end')[:, :, -N:], axis=2)
...:
...: df = pd.DataFrame(out, index=tmp.index, columns=tmp.columns)
...:
338 ms ± 4.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
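Before relying on a faster variant, it may be worth checking that it agrees with a straightforward loop on random data containing NaNs. The following is a sketch of such a check. The function names rolling_mean_last_n and rolling_mean_last_n_loop are hypothetical, and the vectorized version inlines the last-axis justification instead of calling justify_nd:

```python
import warnings
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def rolling_mean_last_n(a, window_size, n):
    """Vectorized: mean of the last n non-NaN values per trailing window along axis 0."""
    pad = np.full((window_size - 1, a.shape[1]), np.nan)
    # Shape (rows, cols, window_size); the window axis is appended last.
    windows = sliding_window_view(np.vstack([pad, a]), window_size, axis=0)
    mask = ~np.isnan(windows)
    justified = np.full(windows.shape, np.nan)
    justified[np.sort(mask, axis=-1)] = windows[mask]   # valid values pushed to the end
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmean(justified[..., -n:], axis=-1)

def rolling_mean_last_n_loop(a, window_size, n):
    """Slow reference implementation used only for comparison."""
    out = np.full(a.shape, np.nan)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            w = a[max(0, i - window_size + 1): i + 1, j]
            valid = w[~np.isnan(w)]
            if valid.size:
                out[i, j] = valid[-n:].mean()
    return out

rng = np.random.default_rng(0)
data = rng.normal(size=(200, 5))
data[rng.random(data.shape) < 0.4] = np.nan   # knock out roughly 40% of the values

fast = rolling_mean_last_n(data, window_size=3, n=2)
slow = rolling_mean_last_n_loop(data, window_size=3, n=2)
print(np.allclose(fast, slow, equal_nan=True))
```

If the two results agree, the vectorized version can be substituted for the loop or for Rolling.apply with more confidence. Memory use grows with window_size, since the justified array holds one full window per cell.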