
Grouped Rolling Mean in Polars


A similar question was asked here.

However, the answer there didn't seem to work in my case.

I have a dataframe with 3 columns: date, groups, and prob. I want to create a 3-day rolling mean of the prob column, grouped by groups and date. However, following the linked answer above, I got all nulls returned.

import polars as pl
from datetime import date
import numpy as np

# 30 consecutive days, repeated once per group
dates = pl.date_range(date(2024, 12, 1), date(2024, 12, 30), "1d", eager=True).alias("date")
days = pl.concat([dates, dates])
# 30 rows of group "B" followed by 30 rows of group "A"
groups = pl.concat([
    pl.select(pl.repeat("B", n=30)).to_series(),
    pl.select(pl.repeat("A", n=30)).to_series(),
]).alias('groups')

data = pl.DataFrame([days, groups])

# Random probabilities (unseeded, so the values change between runs)
data2 = data.with_columns(pl.lit(np.random.rand(data.height)).alias("prob"))

# Attempt 1: row-based rolling mean, partitioned by date and groups
data2.with_columns(
    rolling_mean=pl.col('prob')
    .rolling_mean(window_size=3)
    .over('date', 'groups')
)

"""
shape: (60, 4)
┌────────────┬────────┬──────────┬──────────────┐
│ date       ┆ groups ┆ prob     ┆ rolling_mean │
│ ---        ┆ ---    ┆ ---      ┆ ---          │
│ date       ┆ str    ┆ f64      ┆ f64          │
╞════════════╪════════╪══════════╪══════════════╡
│ 2024-12-01 ┆ B      ┆ 0.938982 ┆ null         │
│ 2024-12-02 ┆ B      ┆ 0.103133 ┆ null         │
│ 2024-12-03 ┆ B      ┆ 0.724672 ┆ null         │
│ 2024-12-04 ┆ B      ┆ 0.495868 ┆ null         │
│ 2024-12-05 ┆ B      ┆ 0.621124 ┆ null         │
│ …          ┆ …      ┆ …        ┆ …            │
│ 2024-12-26 ┆ A      ┆ 0.762529 ┆ null         │
│ 2024-12-27 ┆ A      ┆ 0.766366 ┆ null         │
│ 2024-12-28 ┆ A      ┆ 0.272936 ┆ null         │
│ 2024-12-29 ┆ A      ┆ 0.28709  ┆ null         │
│ 2024-12-30 ┆ A      ┆ 0.403478 ┆ null         │
└────────────┴────────┴──────────┴──────────────┘
"""

In the documentation I found .rolling_mean_by and tried it instead, but rather than computing a rolling mean, it seems to just return the prob value for each row.

data2.with_columns(
    rolling_mean = 
    pl.col('prob')
    .rolling_mean_by(window_size = '3d', by = 'date')
    .over('groups', 'date')
)

"""
shape: (60, 4)
┌────────────┬────────┬──────────┬──────────────┐
│ date       ┆ groups ┆ prob     ┆ rolling_mean │
│ ---        ┆ ---    ┆ ---      ┆ ---          │
│ date       ┆ str    ┆ f64      ┆ f64          │
╞════════════╪════════╪══════════╪══════════════╡
│ 2024-12-01 ┆ B      ┆ 0.938982 ┆ 0.938982     │
│ 2024-12-02 ┆ B      ┆ 0.103133 ┆ 0.103133     │
│ 2024-12-03 ┆ B      ┆ 0.724672 ┆ 0.724672     │
│ 2024-12-04 ┆ B      ┆ 0.495868 ┆ 0.495868     │
│ 2024-12-05 ┆ B      ┆ 0.621124 ┆ 0.621124     │
│ …          ┆ …      ┆ …        ┆ …            │
│ 2024-12-26 ┆ A      ┆ 0.762529 ┆ 0.762529     │
│ 2024-12-27 ┆ A      ┆ 0.766366 ┆ 0.766366     │
│ 2024-12-28 ┆ A      ┆ 0.272936 ┆ 0.272936     │
│ 2024-12-29 ┆ A      ┆ 0.28709  ┆ 0.28709      │
│ 2024-12-30 ┆ A      ┆ 0.403478 ┆ 0.403478     │
└────────────┴────────┴──────────┴──────────────┘
"""

Solution

  • Overall Problem. You partition not only by groups but also by date. Since every (groups, date) pair occurs exactly once, this performs the rolling operation separately for each group and date, i.e. separately for each row.

    Explanation of 1st attempt. As the windows are defined by the groups and date columns, each window consists of a single row. This is fewer than min_samples (called min_periods in older Polars versions), which defaults to window_size (here 3), so the result is null.

    Explanation of 2nd attempt. pl.Expr.rolling_mean_by also has a min_samples argument, but it defaults to 1 rather than to the window size. The mean is therefore computed, but only over the single element in each window, which makes it look as if prob is simply returned.

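    Solution. Exclude date from the partition defined in pl.Expr.over, so that each window can see all rows of its group; rolling_mean_by then uses date to decide which rows fall inside each 3-day window. This looks as follows.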

    data2.with_columns(
        rolling_mean=pl.col('prob').rolling_mean_by(by="date", window_size="3d").over('groups')
    )

    Note that the prob values differ from those in the question because np.random.rand is unseeded.

    shape: (60, 4)
    ┌────────────┬────────┬──────────┬──────────────┐
    │ date       ┆ groups ┆ prob     ┆ rolling_mean │
    │ ---        ┆ ---    ┆ ---      ┆ ---          │
    │ date       ┆ str    ┆ f64      ┆ f64          │
    ╞════════════╪════════╪══════════╪══════════════╡
    │ 2024-12-01 ┆ B      ┆ 0.484882 ┆ 0.484882     │
    │ 2024-12-02 ┆ B      ┆ 0.012538 ┆ 0.24871      │
    │ 2024-12-03 ┆ B      ┆ 0.510953 ┆ 0.336124     │
    │ 2024-12-04 ┆ B      ┆ 0.613973 ┆ 0.379155     │
    │ 2024-12-05 ┆ B      ┆ 0.69837  ┆ 0.607765     │
    │ …          ┆ …      ┆ …        ┆ …            │
    │ 2024-12-26 ┆ A      ┆ 0.948971 ┆ 0.653762     │
    │ 2024-12-27 ┆ A      ┆ 0.905213 ┆ 0.622917     │
    │ 2024-12-28 ┆ A      ┆ 0.986094 ┆ 0.946759     │
    │ 2024-12-29 ┆ A      ┆ 0.286836 ┆ 0.726047     │
    │ 2024-12-30 ┆ A      ┆ 0.78191  ┆ 0.684947     │
    └────────────┴────────┴──────────┴──────────────┘