I have a data frame with ~100M rows containing IDs in different groups. Some of the IDs are wrong (indicated by the 99 in the example). I am trying to correct them with a rolling mode window, as in the code example below. Is there a better way to do this, since rolling_map() is very slow?
import polars as pl
from scipy import stats

def dummy(input):
    return stats.mode(input)[0]

df = pl.DataFrame({'group': [10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20],
                   'id': [1, 1, 99, 1, 1, 2, 2, 3, 3, 99, 3]})

df.with_columns(pl.col('id')
                .rolling_map(function=dummy,
                             window_size=3,
                             min_periods=1,
                             center=True)
                .over('group')
                .alias('id_mode'))
shape: (11, 3)
╭───────┬─────┬─────────╮
│ group ┆ id ┆ id_mode │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═════╪═════════╡
│ 10 ┆ 1 ┆ 1 │
│ 10 ┆ 1 ┆ 1 │
│ 10 ┆ 99 ┆ 1 │
│ 10 ┆ 1 ┆ 1 │
│ 10 ┆ 1 ┆ 1 │
│ 10 ┆ 2 ┆ 2 │
│ 10 ┆ 2 ┆ 2 │
│ 20 ┆ 3 ┆ 3 │
│ 20 ┆ 3 ┆ 3 │
│ 20 ┆ 99 ┆ 3 │
│ 20 ┆ 3 ┆ 3 │
╰───────┴─────┴─────────╯
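There is a frame-level .rolling() method, which stays in "expression land" instead of calling a Python function for every window. Note that with period="3i" and the default offset, each window covers the current row and the two rows before it within the group, so it is trailing rather than centered; that is why row 5 gives 1 in the output below but 2 in the question's output.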
(df.with_row_index()
   .rolling(group_by="group", index_column="index", period="3i")
   .agg(pl.col("id").mode().first())
)
shape: (11, 3)
┌───────┬───────┬─────┐
│ group ┆ index ┆ id │
│ --- ┆ --- ┆ --- │
│ i64 ┆ u32 ┆ i64 │
╞═══════╪═══════╪═════╡
│ 10 ┆ 0 ┆ 1 │
│ 10 ┆ 1 ┆ 1 │
│ 10 ┆ 2 ┆ 1 │
│ 10 ┆ 3 ┆ 1 │
│ 10 ┆ 4 ┆ 1 │
│ 10 ┆ 5 ┆ 1 │
│ 10 ┆ 6 ┆ 2 │
│ 20 ┆ 7 ┆ 3 │
│ 20 ┆ 8 ┆ 3 │
│ 20 ┆ 9 ┆ 3 │
│ 20 ┆ 10 ┆ 3 │
└───────┴───────┴─────┘