Here is a toy example with K=2, but the question is mostly relevant for a high cardinality of g and K >> 1:
import polars as pl
from polars import col

df = pl.DataFrame(dict(
    g=[1, 2, 1, 2, 1, 2],
    v=[1, 2, 3, 4, 5, 6],
))
K = 2
df.with_columns(col.v.shift(k + 1).over('g').alias(f's{k}') for k in range(K))
╭─────┬─────┬──────┬──────╮
│ g ┆ v ┆ s0 ┆ s1 │
│ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪══════╪══════╡
│ 1 ┆ 1 ┆ null ┆ null │
│ 2 ┆ 2 ┆ null ┆ null │
│ 1 ┆ 3 ┆ 1 ┆ null │
│ 2 ┆ 4 ┆ 2 ┆ null │
│ 1 ┆ 5 ┆ 3 ┆ 1 │
│ 2 ┆ 6 ┆ 4 ┆ 2 │
╰─────┴─────┴──────┴──────╯
How can I make sure the grouping by g is done only once?
Polars does not seem to optimize for this in the query plan.
I would expect it to run as fast as:
df.group_by('g').agg((col.v.shift(k+1).alias(f's{k}') for k in range(K)))
Polars caches window expressions. While you may not see this represented in the query plan, the over grouping is only done once.
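As a quick check (a minimal sketch; it assumes a recent Polars version where LazyFrame.explain is available), you can print the optimized plan and see that the repeated over('g') expressions are not shown as a shared node, because the caching is an execution-time detail rather than a plan-level optimization:

import polars as pl
from polars import col

df = pl.DataFrame({"g": [1, 2, 1, 2, 1, 2], "v": [1, 2, 3, 4, 5, 6]})
K = 2

# Build the lazy query and print its optimized plan; the sharing of the
# over('g') grouping happens during execution, so it is not visible here.
plan = (
    df.lazy()
    .with_columns(col.v.shift(k + 1).over("g").alias(f"s{k}") for k in range(K))
    .explain()
)
print(plan)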
Still, your over query will not run as fast as your group_by query. This is because over has to add the results back into the original data frame.
A fairer comparison would be the query below, which will match the result of the over query. As you can see, an additional join is required.
import polars as pl

df = pl.DataFrame(
    {
        "g": [1, 2, 1, 2, 1, 2],
        "v": [1, 2, 3, 4, 5, 6],
    }
)
K = 2

# Group once, compute all K shifted columns per group, then explode back to one row per value.
df_shift = (
    df.group_by("g")
    .agg([pl.col("v")] + [pl.col("v").shift(k + 1).alias(f"s{k}") for k in range(K)])
    .explode(["v"] + [f"s{k}" for k in range(K)])
)

# Join the per-group results back onto the original frame.
result = df.join(df_shift, on=["g", "v"], how="left")
print(result)
shape: (6, 4)
┌─────┬─────┬──────┬──────┐
│ g ┆ v ┆ s0 ┆ s1 │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪══════╪══════╡
│ 1 ┆ 1 ┆ null ┆ null │
│ 2 ┆ 2 ┆ null ┆ null │
│ 1 ┆ 3 ┆ 1 ┆ null │
│ 2 ┆ 4 ┆ 2 ┆ null │
│ 1 ┆ 5 ┆ 3 ┆ 1 │
│ 2 ┆ 6 ┆ 4 ┆ 2 │
└─────┴─────┴──────┴──────┘
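If you want to verify the timing claim on data closer to your real case (high cardinality of g, K >> 1), a rough benchmark along the lines of the sketch below can help. The row count, group count, and K are made-up values for illustration only, and the two queries return different shapes, as discussed above, so this only compares the raw cost of the grouped shift work.

import timeit

import numpy as np
import polars as pl
from polars import col

# Hypothetical sizes, chosen only for illustration.
n_rows, n_groups, K = 1_000_000, 10_000, 20
rng = np.random.default_rng(0)
big = pl.DataFrame({
    "g": rng.integers(0, n_groups, n_rows),
    "v": rng.integers(0, 100, n_rows),
})

def with_over() -> pl.DataFrame:
    # K window expressions over the same grouping key.
    return big.with_columns(
        col.v.shift(k + 1).over("g").alias(f"s{k}") for k in range(K)
    )

def with_group_by() -> pl.DataFrame:
    # Same shifts, but aggregated per group (different output shape).
    return big.group_by("g").agg(
        col.v.shift(k + 1).alias(f"s{k}") for k in range(K)
    )

print("over:    ", timeit.timeit(with_over, number=5))
print("group_by:", timeit.timeit(with_group_by, number=5))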