
How can I compute a shifted expanding mean per group


I want the expanding mean of col2 within each group of groupby('col1'), but the mean should not include the row itself (only the rows above it).

import pandas as pd

# col2 holds integers so that the mean is computed on numbers, not strings
dummy = pd.DataFrame({"col1": ["a", "a", "a", "b", "b", "b", "c", "c"], "col2": [1, 2, 3, 4, 5, 6, 7, 8]}, index=list(range(8)))
print(dummy)
dummy['one_liner'] = dummy.groupby('col1').col2.shift().expanding().mean().reset_index(level=0, drop=True)
dummy['two_liner'] = dummy.groupby('col1').col2.shift()
dummy['two_liner'] = dummy.groupby('col1').two_liner.expanding().mean().reset_index(level=0, drop=True)
print(dummy)
---------------------------
Here is the result of the first print statement:
  col1 col2
0    a    1
1    a    2
2    a    3
3    b    4
4    b    5
5    b    6
6    c    7
7    c    8
Here is the result of the second print statement:
  col1 col2  one_liner  two_liner
0    a    1        NaN        NaN
1    a    2   1.000000        1.0
2    a    3   1.500000        1.5
3    b    4   1.500000        NaN
4    b    5   2.333333        4.0
5    b    6   3.000000        4.5
6    c    7   3.000000        NaN
7    c    8   3.800000        7.0

I expected both columns to be identical. two_liner is the expected result, while one_liner mixes values across groups.

It took me a long time to find this solution. Can anyone explain the logic? Why does one_liner not give the expected result?


Solution

  • Both expanding().mean() and shift() need to run inside the groupby, so that each operation only ever sees the rows of one group:

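    groups = dummy.groupby('col1')
    # transform returns a Series aligned to the original index, so it can be assigned directly
    dummy['one_liner'] = groups.col2.transform(lambda x: x.expanding().mean().shift())
    
    # the same per-group operation applied again, this time to one_liner
    dummy['two_liner'] = dummy.groupby('col1').one_liner.transform(lambda x: x.expanding().mean().shift())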
    

    Output:

      col1  col2  one_liner  two_liner
    0    a     1        NaN        NaN
    1    a     2        1.0        NaN
    2    a     3        1.5        1.0
    3    b     4        NaN        NaN
    4    b     5        4.0        NaN
    5    b     6        4.5        4.0
    6    c     7        NaN        NaN
    7    c     8        7.0        NaN
    

    Explanation:

    (dummy.groupby('col1').col2.shift()    # shifts col2 within each group; the result is a plain Series on the original index
         .expanding().mean()               # no grouping is left here: the window expands over the whole Series
         .reset_index(level=0, drop=True)  # only resets the (already flat) index; the values are unchanged
    )
    

    So the chained command above is equivalent to:

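    s1 = dummy.groupby('col1').col2.shift()   # grouped: the shift stays within each group
    s2 = s1.expanding().mean()                 # not grouped: s1 is an ordinary Series
    s3 = s2.reset_index(level=0, drop=True)    # index reset only; values unchanged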
    

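    Only s1 is computed per group. s1 is an ordinary Series on the original index with no record of col1, so s1.expanding() runs over all rows in order and the running mean carries values from group a into groups b and c (for example, row 3 of one_liner is 1.5, the mean of group a's values 1 and 2).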
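A quick way to confirm this explanation is to inspect the intermediate objects. The sketch below assumes a reasonably recent pandas and rebuilds the question's data with integer col2 values; `s_grouped` is a name introduced only for this illustration. It also shows why the two-step version needs `reset_index(level=0, drop=True)`: a grouped `expanding()` keeps the group key as an extra index level.

```python
import pandas as pd

# The question's data, with col2 stored as integers
dummy = pd.DataFrame({"col1": ["a", "a", "a", "b", "b", "b", "c", "c"],
                      "col2": [1, 2, 3, 4, 5, 6, 7, 8]})

s1 = dummy.groupby("col1")["col2"].shift()
# s1 is a plain Series on the original flat index: nothing records the groups any more
print(type(s1).__name__, s1.index.equals(dummy.index))

s2 = s1.expanding().mean()
# Row 4 belongs to group b, yet it averages 1 and 2 (group a) together with 4: (1 + 2 + 4) / 3
print(s2.iloc[4])

# A grouped expanding window keeps col1 as an extra index level,
# which is why the two-step version drops level 0 before assigning
s_grouped = s1.groupby(dummy["col1"]).expanding().mean()
print(s_grouped.index.nlevels)
print(s_grouped.reset_index(level=0, drop=True).sort_index().tolist())
```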
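If you want to avoid the Python-level lambda, a different technique (not the one used in the answer above) gives the same numbers: for each row, subtract the current value from the per-group cumulative sum and divide by the number of earlier rows in the group. This is a sketch that assumes col2 has no missing values; with NaNs, `cumcount()` would still count those rows and the result would no longer match `expanding().mean()`. `prev_mean` is a hypothetical column name.

```python
import pandas as pd

# Assumes col2 contains no NaN values
dummy = pd.DataFrame({"col1": ["a", "a", "a", "b", "b", "b", "c", "c"],
                      "col2": [1, 2, 3, 4, 5, 6, 7, 8]})

g = dummy.groupby("col1")["col2"]
# Sum of the earlier rows in the same group (current row excluded)
prev_sum = g.cumsum() - dummy["col2"]
# Number of earlier rows in the same group: 0 for the first row of each group
prev_count = g.cumcount()
# 0 / 0 on the first row of each group yields NaN, like the shifted expanding mean
dummy["prev_mean"] = prev_sum / prev_count

# Cross-check against the per-group transform approach
expected = g.transform(lambda x: x.expanding().mean().shift())
print(dummy[["col1", "col2", "prev_mean"]])
```

Both columns should agree row by row; the cumulative-sum version stays vectorized, at the cost of the no-NaN assumption.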