
Polars Rolling Mean, fill start of window with null instead of shortened window


Is there a way to get null at the start of a rolling window in Polars until there is enough data to fill the whole window? For example:

import polars as pl

dates = [
    "2020-01-01",
    "2020-01-02",
    "2020-01-03",
    "2020-01-04",
    "2020-01-05",
    "2020-01-06",
    "2020-01-01",
    "2020-01-02",
    "2020-01-03",
    "2020-01-04",
    "2020-01-05",
    "2020-01-06",
]
df = pl.DataFrame(
    {
        "dt": dates,
        "a": [3, 4, 2, 8, 10, 1, 1, 7, 5, 9, 2, 1],
        "b": ["Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", "No", "No", "No", "No"],
    }
).with_columns(pl.col("dt").str.strptime(pl.Date))
# Sort by date; this also sets the sorted flag on "dt", so no manual set_sorted() is needed
df = df.sort(by="dt")

df.rolling(
    index_column="dt", period="2d", group_by="b"
).agg(pl.col("a").mean().alias("ma_2d"))

Result (truncated):

b      dt          ma_2d
str    date        f64
"Yes"  2020-01-01  3.0
"Yes"  2020-01-02  3.5
"Yes"  2020-01-03  3.0
"Yes"  2020-01-04  5.0
"Yes"  2020-01-05  9.0

I expect the first day to be null, because there are not yet two days of data to fill the window. Instead, Polars seems to shrink the window at the start and average whatever rows are available.


Solution

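  • You can check the window length with pl.len() and only compute the mean when the window contains more than one row, i.e. when both days are present: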

    (df.rolling(index_column="dt", period="2d", group_by="b")
       .agg(
          pl.when(pl.len() > 1)
            .then(pl.col("a").mean())
            .alias("ma_2d")
       )
    )
    
    shape: (12, 3)
    ┌─────┬────────────┬───────┐
    │ b   ┆ dt         ┆ ma_2d │
    │ --- ┆ ---        ┆ ---   │
    │ str ┆ date       ┆ f64   │
    ╞═════╪════════════╪═══════╡
    │ Yes ┆ 2020-01-01 ┆ null  │
    │ Yes ┆ 2020-01-02 ┆ 3.5   │
    │ Yes ┆ 2020-01-03 ┆ 3.0   │
    │ Yes ┆ 2020-01-04 ┆ 5.0   │
    │ Yes ┆ 2020-01-05 ┆ 9.0   │
    │ …   ┆ …          ┆ …     │
    │ No  ┆ 2020-01-02 ┆ 4.0   │
    │ No  ┆ 2020-01-03 ┆ 6.0   │
    │ No  ┆ 2020-01-04 ┆ 7.0   │
    │ No  ┆ 2020-01-05 ┆ 5.5   │
    │ No  ┆ 2020-01-06 ┆ 1.5   │
    └─────┴────────────┴───────┘
    

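    Alternatively, the dedicated .rolling_mean_by() expression accepts min_periods, the minimum number of rows a window must contain before a value is returned (newer Polars releases rename this parameter to min_samples).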

    df.with_columns(
       pl.col("a").rolling_mean_by("dt", window_size="2d", min_periods=2)
         .over("b")
         .alias("ma_2d")
    )
    
    shape: (12, 4)
    ┌────────────┬─────┬─────┬───────┐
    │ dt         ┆ a   ┆ b   ┆ ma_2d │
    │ ---        ┆ --- ┆ --- ┆ ---   │
    │ date       ┆ i64 ┆ str ┆ f64   │
    ╞════════════╪═════╪═════╪═══════╡
    │ 2020-01-01 ┆ 3   ┆ Yes ┆ null  │
    │ 2020-01-01 ┆ 1   ┆ No  ┆ null  │
    │ 2020-01-02 ┆ 4   ┆ Yes ┆ 3.5   │
    │ 2020-01-02 ┆ 7   ┆ No  ┆ 4.0   │
    │ 2020-01-03 ┆ 2   ┆ Yes ┆ 3.0   │
    │ …          ┆ …   ┆ …   ┆ …     │
    │ 2020-01-04 ┆ 9   ┆ No  ┆ 7.0   │
    │ 2020-01-05 ┆ 10  ┆ Yes ┆ 9.0   │
    │ 2020-01-05 ┆ 2   ┆ No  ┆ 5.5   │
    │ 2020-01-06 ┆ 1   ┆ Yes ┆ 5.5   │
    │ 2020-01-06 ┆ 1   ┆ No  ┆ 1.5   │
    └────────────┴─────┴─────┴───────┘