Using Polars .rolling and .agg, how do I get the original column back in its original (non-list) form, without joining back to the original DataFrame and without using .over?
Example:
import polars as pl
dates = [
"2020-01-01 13:45:48",
"2020-01-01 16:42:13",
"2020-01-01 16:45:09",
"2020-01-02 18:12:48",
"2020-01-03 19:45:32",
"2020-01-08 23:16:43",
]
df = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}).with_columns(
pl.col("dt").str.to_datetime().set_sorted()
)
This gives me a small Polars DataFrame:
┌─────────────────────┬─────┐
│ dt ┆ a │
│ --- ┆ --- │
│ datetime[μs] ┆ i64 │
╞═════════════════════╪═════╡
│ 2020-01-01 13:45:48 ┆ 3 │
│ 2020-01-01 16:42:13 ┆ 7 │
│ 2020-01-01 16:45:09 ┆ 5 │
│ 2020-01-02 18:12:48 ┆ 9 │
│ 2020-01-03 19:45:32 ┆ 2 │
│ 2020-01-08 23:16:43 ┆ 1 │
└─────────────────────┴─────┘
When I apply a rolling aggregation, I get the new columns back, but the original column only comes back as a list:
out = df.rolling(index_column="dt", period="2d").agg(
pl.sum("a").alias("sum_a"),
pl.min("a").alias("min_a"),
pl.max("a").alias("max_a"),
pl.col("a")
)
which gives:
┌─────────────────────┬───────┬───────┬───────┬──────────────┐
│ dt ┆ sum_a ┆ min_a ┆ max_a ┆ a │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ list[i64] │
╞═════════════════════╪═══════╪═══════╪═══════╪══════════════╡
│ 2020-01-01 13:45:48 ┆ 3 ┆ 3 ┆ 3 ┆ [3] │
│ 2020-01-01 16:42:13 ┆ 10 ┆ 3 ┆ 7 ┆ [3, 7] │
│ 2020-01-01 16:45:09 ┆ 15 ┆ 3 ┆ 7 ┆ [3, 7, 5] │
│ 2020-01-02 18:12:48 ┆ 24 ┆ 3 ┆ 9 ┆ [3, 7, 5, 9] │
│ 2020-01-03 19:45:32 ┆ 11 ┆ 2 ┆ 9 ┆ [9, 2] │
│ 2020-01-08 23:16:43 ┆ 1 ┆ 1 ┆ 1 ┆ [1] │
└─────────────────────┴───────┴───────┴───────┴──────────────┘
How can I get the original a column back? I don't want to join, and I don't want to use .over, as I need the group_by parameter of .rolling later on and .over does not work with .rolling.
Edit: I am also not keen on the following workaround:
out = df.rolling(index_column="dt", period="2d").agg(
pl.sum("a").alias("sum_a"),
pl.min("a").alias("min_a"),
pl.max("a").alias("max_a"),
pl.col("a").last()
)
Edit 2: why Expr.rolling() is not feasible and why I need the group_by:
Given a more elaborate example:
dates = [
"2020-01-01 13:45:48",
"2020-01-01 16:42:13",
"2020-01-01 16:45:09",
"2020-01-02 18:12:48",
"2020-01-03 19:45:32",
"2020-01-08 23:16:43",
]
df_a = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1],"cat":["one"]*6}).with_columns(
pl.col("dt").str.to_datetime()
)
df_b = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1],"cat":["two"]*6}).with_columns(
pl.col("dt").str.to_datetime()
)
df = pl.concat([df_a, df_b])

which gives:
┌─────────────────────┬─────┬─────┐
│ dt ┆ a ┆ cat │
│ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ i64 ┆ str │
╞═════════════════════╪═════╪═════╡
│ 2020-01-01 13:45:48 ┆ 3 ┆ one │
│ 2020-01-01 16:42:13 ┆ 7 ┆ one │
│ 2020-01-01 16:45:09 ┆ 5 ┆ one │
│ 2020-01-02 18:12:48 ┆ 9 ┆ one │
│ 2020-01-03 19:45:32 ┆ 2 ┆ one │
│ 2020-01-08 23:16:43 ┆ 1 ┆ one │
│ 2020-01-01 13:45:48 ┆ 3 ┆ two │
│ 2020-01-01 16:42:13 ┆ 7 ┆ two │
│ 2020-01-01 16:45:09 ┆ 5 ┆ two │
│ 2020-01-02 18:12:48 ┆ 9 ┆ two │
│ 2020-01-03 19:45:32 ┆ 2 ┆ two │
│ 2020-01-08 23:16:43 ┆ 1 ┆ two │
└─────────────────────┴─────┴─────┘
and the code:
out = df.rolling(index_column="dt", period="2d",group_by="cat").agg(
pl.sum("a").alias("sum_a"),
pl.min("a").alias("min_a"),
pl.max("a").alias("max_a"),
pl.col("a")
)

which gives:
┌─────┬─────────────────────┬───────┬───────┬───────┬──────────────┐
│ cat ┆ dt ┆ sum_a ┆ min_a ┆ max_a ┆ a │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ list[i64] │
╞═════╪═════════════════════╪═══════╪═══════╪═══════╪══════════════╡
│ one ┆ 2020-01-01 13:45:48 ┆ 3 ┆ 3 ┆ 3 ┆ [3] │
│ one ┆ 2020-01-01 16:42:13 ┆ 10 ┆ 3 ┆ 7 ┆ [3, 7] │
│ one ┆ 2020-01-01 16:45:09 ┆ 15 ┆ 3 ┆ 7 ┆ [3, 7, 5] │
│ one ┆ 2020-01-02 18:12:48 ┆ 24 ┆ 3 ┆ 9 ┆ [3, 7, 5, 9] │
│ one ┆ 2020-01-03 19:45:32 ┆ 11 ┆ 2 ┆ 9 ┆ [9, 2] │
│ one ┆ 2020-01-08 23:16:43 ┆ 1 ┆ 1 ┆ 1 ┆ [1] │
│ two ┆ 2020-01-01 13:45:48 ┆ 3 ┆ 3 ┆ 3 ┆ [3] │
│ two ┆ 2020-01-01 16:42:13 ┆ 10 ┆ 3 ┆ 7 ┆ [3, 7] │
│ two ┆ 2020-01-01 16:45:09 ┆ 15 ┆ 3 ┆ 7 ┆ [3, 7, 5] │
│ two ┆ 2020-01-02 18:12:48 ┆ 24 ┆ 3 ┆ 9 ┆ [3, 7, 5, 9] │
│ two ┆ 2020-01-03 19:45:32 ┆ 11 ┆ 2 ┆ 9 ┆ [9, 2] │
│ two ┆ 2020-01-08 23:16:43 ┆ 1 ┆ 1 ┆ 1 ┆ [1] │
└─────┴─────────────────────┴───────┴───────┴───────┴──────────────┘
This does not work:
df.sort("dt").with_columns(sum=pl.sum("a").rolling(index_column="dt", period="2d").over("cat"))
Gives:
# InvalidOperationError: rolling expression not allowed in aggregation
There are dedicated rolling_*_by expressions (Expr.rolling_sum_by, Expr.rolling_min_by, etc.) which can be used with .over():
df.with_columns(
pl.col("a").rolling_sum_by("dt", "2d").over("cat").name.prefix("sum_"),
pl.col("a").rolling_min_by("dt", "2d").over("cat").name.prefix("min_"),
pl.col("a").rolling_max_by("dt", "2d").over("cat").name.prefix("max_")
)
shape: (12, 6)
┌─────────────────────┬─────┬─────┬───────┬───────┬───────┐
│ dt ┆ a ┆ cat ┆ sum_a ┆ min_a ┆ max_a │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ datetime[μs] ┆ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 │
╞═════════════════════╪═════╪═════╪═══════╪═══════╪═══════╡
│ 2020-01-01 13:45:48 ┆ 3 ┆ one ┆ 3 ┆ 3 ┆ 3 │
│ 2020-01-01 16:42:13 ┆ 7 ┆ one ┆ 10 ┆ 3 ┆ 7 │
│ 2020-01-01 16:45:09 ┆ 5 ┆ one ┆ 15 ┆ 3 ┆ 7 │
│ 2020-01-02 18:12:48 ┆ 9 ┆ one ┆ 24 ┆ 3 ┆ 9 │
│ 2020-01-03 19:45:32 ┆ 2 ┆ one ┆ 11 ┆ 2 ┆ 9 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 2020-01-01 16:42:13 ┆ 7 ┆ two ┆ 10 ┆ 3 ┆ 7 │
│ 2020-01-01 16:45:09 ┆ 5 ┆ two ┆ 15 ┆ 3 ┆ 7 │
│ 2020-01-02 18:12:48 ┆ 9 ┆ two ┆ 24 ┆ 3 ┆ 9 │
│ 2020-01-03 19:45:32 ┆ 2 ┆ two ┆ 11 ┆ 2 ┆ 9 │
│ 2020-01-08 23:16:43 ┆ 1 ┆ two ┆ 1 ┆ 1 ┆ 1 │
└─────────────────────┴─────┴─────┴───────┴───────┴───────┘