I received an unknown error in Python Polars:
thread '<unnamed>' panicked at 'assertion failed: `(left == right)`
left: `Float64[NaN, 1, NaN, NaN, NaN, ...[clip]...
right: `Float64[NaN, 1, NaN, NaN, NaN, ...[clip]...
Is this an internal error?
The code that triggers it is:
df.select([
    pl.col('total').shift().ewm_mean(half_life=10).over('group')
])
It's hard for me to ask more because the error is so inscrutable.
Edit: this issue is fixed in Polars 0.13.19 and above, so a workaround is no longer needed.
On older versions, another temporary workaround is to compute the result of a shift over a window in a different way.
Let's say we have the following groups, numbered observations, and totals.
import numpy as np
import polars as pl

df = pl.DataFrame(
    {
        "group": ["a", "a", "b", "a", "b", "b"],
        "obs": [1, 2, 1, 3, 2, 3],
        # np.nan rather than np.NaN: the capitalized alias was removed in NumPy 2.0
        "total": [1.0, 2, 3, 4, 5, np.nan],
    }
)
df
shape: (6, 3)
┌───────┬─────┬───────┐
│ group ┆ obs ┆ total │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 │
╞═══════╪═════╪═══════╡
│ a ┆ 1 ┆ 1.0 │
│ a ┆ 2 ┆ 2.0 │
│ b ┆ 1 ┆ 3.0 │
│ a ┆ 3 ┆ 4.0 │
│ b ┆ 2 ┆ 5.0 │
│ b ┆ 3 ┆ NaN │
└───────┴─────┴───────┘
The following code arrives at the same result as a shift over the groups:
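df = (
    # Sort so that each group's rows are contiguous and ordered by obs.
    df.sort("group", "obs")
    # Shift the whole column by one row, ignoring group boundaries.
    .with_columns(pl.col("total").shift().alias("total_shifted"))
    .with_columns(
        # The first row of each group has no predecessor within its group.
        # Note: newer Polars releases name this method is_first_distinct().
        pl.when(pl.col("group").is_first())
        .then(None)
        .otherwise(pl.col("total_shifted"))
        .alias("result")
    )
)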
df
shape: (6, 5)
┌───────┬─────┬───────┬───────────────┬────────┐
│ group ┆ obs ┆ total ┆ total_shifted ┆ result │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 │
╞═══════╪═════╪═══════╪═══════════════╪════════╡
│ a ┆ 1 ┆ 1.0 ┆ null ┆ null │
│ a ┆ 2 ┆ 2.0 ┆ 1.0 ┆ 1.0 │
│ a ┆ 3 ┆ 4.0 ┆ 2.0 ┆ 2.0 │
│ b ┆ 1 ┆ 3.0 ┆ 4.0 ┆ null │
│ b ┆ 2 ┆ 5.0 ┆ 3.0 ┆ 3.0 │
│ b ┆ 3 ┆ NaN ┆ 5.0 ┆ 5.0 │
└───────┴─────┴───────┴───────────────┴────────┘
(I've left the intermediate calculations in the dataset for inspection, to show how the algorithm works.)
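If it helps to see why this works, here is a minimal plain-Python sketch (no Polars involved) of the same idea on the example data. The function names `shift_then_mask` and `shift_per_group` are made up for illustration. The first mimics the workaround (sort, shift everything, blank the first row of each group); the second does a straightforward shift within each group. They should agree.

```python
from itertools import groupby

# The example rows as (group, obs, total).
rows = [
    ("a", 1, 1.0), ("a", 2, 2.0), ("b", 1, 3.0),
    ("a", 3, 4.0), ("b", 2, 5.0), ("b", 3, float("nan")),
]


def shift_then_mask(rows):
    """Sort, shift the whole column by one, then blank each group's first row."""
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    totals = [r[2] for r in ordered]
    shifted = [None] + totals[:-1]  # global shift, ignores group boundaries
    seen = set()
    result = []
    for (group, _, _), value in zip(ordered, shifted):
        if group not in seen:  # first row of this group: no predecessor
            seen.add(group)
            result.append(None)
        else:
            result.append(value)
    return result


def shift_per_group(rows):
    """Shift by one inside each group separately."""
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    result = []
    for _, members in groupby(ordered, key=lambda r: r[0]):
        totals = [r[2] for r in members]
        result.extend([None] + totals[:-1])
    return result


masked = shift_then_mask(rows)
grouped = shift_per_group(rows)
print(masked)
print(masked == grouped)
```

The only value the global shift carries across a group boundary lands on the first row of the next group, which is exactly the row that gets blanked, so the two approaches match.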
Notice that the result column holds the same values you'd obtain from a shift over groups. You can then run your aggregations on the result column, without needing to use shift.
df.select(
    pl.col('result').ewm_mean(half_life=10).over('group')
)
Of course, you'll have to adapt this to your particular code, but it should work.
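As a rough sanity check on what `half_life = 10` means, here is a small plain-Python sketch. It assumes the usual half-life convention, alpha = 1 - exp(-ln(2) / half_life), and the "adjusted" weighting (weights (1 - alpha)^i), which I believe is what Polars uses by default. The helper names are hypothetical, and null/NaN handling is not modeled, so treat it as an approximation of the idea rather than a reimplementation of `ewm_mean`.

```python
import math


def alpha_from_half_life(half_life):
    """Smoothing factor for a given half-life (assumed convention)."""
    return 1.0 - math.exp(-math.log(2.0) / half_life)


def ewm_mean_adjusted(values, half_life):
    """Exponentially weighted mean with adjusted weights (1 - alpha)**i.

    Assumes the input has no missing values.
    """
    decay = 1.0 - alpha_from_half_life(half_life)
    numerator = 0.0
    denominator = 0.0
    out = []
    for x in values:
        # Older observations are discounted by `decay` at every step.
        numerator = x + decay * numerator
        denominator = 1.0 + decay * denominator
        out.append(numerator / denominator)
    return out


alpha = alpha_from_half_life(10)
smoothed = ewm_mean_adjusted([1.0, 2.0], half_life=10)
print(alpha)
print(smoothed)
```

With a half-life of 10, an observation's weight drops to one half after 10 rows, so within short groups like the ones in this example the result stays close to a plain running mean.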