pandas dataframe group-by rolling-average

Rolling Average with variable min_periods from another column


I have a dataframe with multiple accounts spanning the last few years, and I am trying to compute a rolling average of a column per account, which is easy enough. However, I also need min_periods to vary by account. The dataframe has a column holding the min_periods value (minuploads), and the window is always 10, but I cannot find a way to do this. Ultimately, I want a rolling average of the last 10 values. Each account falls into a category that requires a different minimum number of previous values, which I have stored in minuploads. If there are not enough previous values, I don't want the rolling average to be calculated; I will use a modeled value for those rows instead.

Here is a sample of the dataframe:

accountID date diffs minuploads
1001 2020-01-12 7 8
1001 2020-01-19 7 8
1001 2020-01-20 1 8
1001 2020-01-26 6 8
1002 2020-01-10 2 10
1002 2020-01-13 3 10
1002 2020-01-16 3 10
1002 2020-01-18 2 10

Initially I didn't use a variable min_periods and used this code for the rolling average:

df['ma'] = df.groupby('accountID')['diffs'].transform(lambda x: x.rolling(10,5).mean())

That worked great, but I then decided I need a variable min_periods. I tried this, but I get a KeyError:

df['ma'] = df.groupby('accountID')['diffs'].transform(lambda x: x.rolling(10,x['minuploads']).mean())

I also tried this to make sure the groupby had both columns, but I still get KeyError: 'minuploads':

df['ma'] = df.groupby('accountID')[['diffs','minuploads']].transform(lambda x: x.rolling(10,x['minuploads']).mean())

Solution

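  • minuploads shouldn't be a column because it's not per-row, it's per-account. In your first attempt, x inside the lambda is the diffs Series for one account, so x['minuploads'] is a label lookup on that Series, which is why you get the KeyError. You can just use a dict keyed by accountID instead, or a Series if you prefer.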

    I had to lower the numbers so they work with this example data (8 and 10 became 2 and 3), since each account has only four rows.

    minuploads = {
        1001: 2,
        1002: 3,
    }
    df['ma'] = df.groupby('accountID')['diffs'].transform(
        lambda x: x.rolling(10, minuploads[x.name]).mean())
    

    Result:

       accountID        date  diffs        ma
    0       1001  2020-01-12      7       NaN
    1       1001  2020-01-19      7  7.000000
    2       1001  2020-01-20      1  5.000000
    3       1001  2020-01-26      6  5.250000
    4       1002  2020-01-10      2       NaN
    5       1002  2020-01-13      3       NaN
    6       1002  2020-01-16      3  2.666667
    7       1002  2020-01-18      2  2.500000
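
If minuploads already exists as a column, as in the question, the per-account lookup does not have to be typed by hand. The following is a minimal sketch rather than a definitive implementation. It assumes minuploads is constant within each account, and the name min_by_account is only illustrative. It uses the same lowered values (2 and 3) as above.

```python
import pandas as pd

# Sample data from the question, with minuploads lowered to 2 and 3
# so that the four rows per account produce some non-NaN averages.
df = pd.DataFrame({
    "accountID": [1001] * 4 + [1002] * 4,
    "date": ["2020-01-12", "2020-01-19", "2020-01-20", "2020-01-26",
             "2020-01-10", "2020-01-13", "2020-01-16", "2020-01-18"],
    "diffs": [7, 7, 1, 6, 2, 3, 3, 2],
    "minuploads": [2, 2, 2, 2, 3, 3, 3, 3],
})

# Assumption: minuploads is the same on every row of an account,
# so the first value per group represents the whole account.
min_by_account = df.groupby("accountID")["minuploads"].first()

# Inside transform, x.name is the accountID of the current group,
# which is used to look up that account's min_periods.
df["ma"] = df.groupby("accountID")["diffs"].transform(
    lambda x: x.rolling(10, min_periods=int(min_by_account[x.name])).mean()
)

print(df)
```

Rows where ma is NaN are the ones that did not reach the account's minimum, so those are the rows to fill with the modeled value.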