python, pandas, dataframe, pandas-rolling

`pandas` rolling mean with a maximum number of valid observations in a window


I am looking for help speeding up a rolling calculation in pandas that computes a rolling average over at most a predefined number of the most recent valid observations in each window. Here is code to generate an example frame, and the frame itself:

import pandas as pd
import numpy as np

tmp = pd.DataFrame(
    [
        [11.1]*3 + [12.1]*3 + [13.1]*3  + [14.1]*3 + [15.1]*3 + [16.1]*3 + [17.1]*3 + [18.1]*3,
        ['A', 'B', 'C']*8,
        [np.nan]*6 + [1, 1, 1] + [2, 2, 2] + [3, 3, 3] + [np.nan]*9
    ],
    index=['Date', 'Name', 'Val']
)
tmp = tmp.T.pivot(index='Date', columns='Name', values='Val')

Name    A    B    C
Date               
11.1  NaN  NaN  NaN
12.1  NaN  NaN  NaN
13.1    1    1    1
14.1    2    2    2
15.1    3    3    3
16.1  NaN  NaN  NaN
17.1  NaN  NaN  NaN
18.1  NaN  NaN  NaN

I would like to obtain this result:

Name    A    B    C
Date               
11.1  NaN  NaN  NaN
12.1  NaN  NaN  NaN
13.1  1.0  1.0  1.0
14.1  1.5  1.5  1.5
15.1  2.5  2.5  2.5
16.1  2.5  2.5  2.5
17.1  3.0  3.0  3.0
18.1  NaN  NaN  NaN
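
For example, with window=3 and a maximum of 2 valid observations: at 15.1 the window holds [1, 2, 3], so the last two valid values give mean(2, 3) = 2.5; at 16.1 the window holds [2, 3, NaN], again giving 2.5; and at 17.1 only [3] remains, giving 3.0.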

Attempted Solution

I tried the following code. It works, but its performance is far too slow for the data sets I work with in practice.

tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)

Applied to a 3k x 50k frame, the calculation above takes about 20 minutes. Is there a more elegant and faster way to obtain the same result, perhaps by combining multiple rolling computations or by using groupby?

Versions

Python 3.9.13, pandas 2.0.3, numpy 1.25.2


Solution

  • One idea is to speed up `Rolling.apply` with numba, via the parameter engine='numba':

    (tmp.rolling(window=3, min_periods=1)
        .apply(lambda x: x[~np.isnan(x)][-2:].mean(), engine='numba', raw=True))
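
    With engine='numba' the applied function is JIT-compiled, so each per-window call avoids Python-level overhead. Note that the first call pays a one-time compilation cost, which %timeit amortizes over repeated runs.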
    

    Test performance:

    tmp = pd.concat([tmp] * 100000, ignore_index=True)
    
    In [88]: %timeit tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(),engine='numba', raw=True)
    901 ms ± 6.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [89]: %timeit tmp.rolling(window=3, min_periods=1).apply(lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)
    13 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
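
    On this data the numba engine is roughly 14x faster than the default Python engine.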
    

    Numpy approach:

    You can convert the DataFrame to a 3d array of sliding windows (prepending NaN rows so every row gets a full window), justify the non-NaN values to the end of each window, and take the mean of the last N values:

    #https://stackoverflow.com/a/44559180/2901002
    def justify_nd(a, invalid_val, axis, side):    
        """
        Justify ndarray for the valid elements (that are not invalid_val).
    
        Parameters
        ----------
        a : ndarray
            Input array to be justified
        invalid_val : scalar
            invalid value
        axis : int
            Axis along which justification is to be made
        side : str
            Direction of justification. Must be 'front' or 'end'.
            So, with 'front', valid elements are pushed to the front and
            with 'end' valid elements are pushed to the end along specified axis.
        """
        
        pushax = lambda a: np.moveaxis(a, axis, -1)
        if invalid_val is np.nan:
            mask = ~np.isnan(a)
        else:
            mask = a!=invalid_val
        justified_mask = np.sort(mask,axis=axis)
        
        if side=='front':
            justified_mask = np.flip(justified_mask,axis=axis)
                
        out = np.full(a.shape, invalid_val)
        if (axis==-1) or (axis==a.ndim-1):
            out[justified_mask] = a[mask]
        else:
            pushax(out)[pushax(justified_mask)] = pushax(a)[pushax(mask)]
        return out
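
    A quick check of what justify_nd does, on a toy 1d array (illustration only, not from the frame above):

    x = np.array([np.nan, 1.0, np.nan, 2.0])
    print(justify_nd(x, invalid_val=np.nan, axis=0, side='end'))
    # [nan nan  1.  2.]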
    

    from numpy.lib.stride_tricks import sliding_window_view as swv
    
    window_size = 3
    N = 2
    
    a = tmp.astype(float).to_numpy()
    arr = np.vstack([np.full((window_size-1,a.shape[1]), np.nan),a])
    
    out = np.nanmean(justify_nd(swv(arr, window_size, axis=0), 
                                invalid_val=np.nan, axis=2, side='end')[:, :, -N:], 
                     axis=2)
    
    print (out)
    [[nan nan nan]
     [nan nan nan]
     [1.  1.  1. ]
     [1.5 1.5 1.5]
     [2.5 2.5 2.5]
     [2.5 2.5 2.5]
     [3.  3.  3. ]
     [nan nan nan]]
    

    df = pd.DataFrame(out, index=tmp.index, columns=tmp.columns)
    print (df)
    Name    A    B    C
    Date               
    11.1  NaN  NaN  NaN
    12.1  NaN  NaN  NaN
    13.1  1.0  1.0  1.0
    14.1  1.5  1.5  1.5
    15.1  2.5  2.5  2.5
    16.1  2.5  2.5  2.5
    17.1  3.0  3.0  3.0
    18.1  NaN  NaN  NaN
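
    As a quick sanity check, the result can be compared against the rolling apply from the question (a minimal check, assuming df and tmp as defined above):

    expected = tmp.astype(float).rolling(window=3, min_periods=1).apply(
        lambda x: x[~np.isnan(x)][-2:].mean(), raw=True)
    pd.testing.assert_frame_equal(df, expected)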
    

    Performance:

    tmp = pd.concat([tmp] * 100000, ignore_index=True)
    
    
    In [99]: %%timeit
        ...: a = tmp.astype(float).to_numpy()
        ...: arr = np.vstack([np.full((window_size-1,a.shape[1]), np.nan),a])
        ...: 
        ...: out = np.nanmean(justify_nd(swv(arr, window_size, axis=0), 
        ...:                             invalid_val=np.nan, axis=2, side='end')[:, :, -N:], 
        ...:                  axis=2)
        ...: 
        ...: df = pd.DataFrame(out, index=tmp.index, columns=tmp.columns)
        ...: 
    
    338 ms ± 4.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
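
    That is roughly 2.5x faster than the numba rolling apply above (338 ms vs. 901 ms), and about 40x faster than the original approach.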