I recently came across some behaviour in some of the Polars rolling functions that I don't understand. The problem seems to present itself only when there is a NaN (`np.nan`) as well as a null (`None`) in the data. I am on Polars 1.31.0, NumPy 2.3.1 and Python 3.12.11.
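For context on the distinction the examples below rely on: Polars treats NaN as a floating-point value and null as a missing value. A minimal sketch using the standard `is_nan`/`is_null` expressions:
import polars as pl

# NaN is a float value; null is a missing value. On my version,
# is_nan leaves null elements as null, while is_null reports them.
s = pl.Series("x", [1.0, float("nan"), None])
print(s.is_nan())   # [false, true, null]
print(s.is_null())  # [false, false, true]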
Here are some examples; note that I will be using the `rolling_sum` function with a window size of 2 in all of them.
First, a non-problematic application with the expected behaviour:
import numpy as np
import polars as pl

data_dict_1 = {"x": [1., 1., 1., np.nan, 1., 1., 1., 1., 1.]}
This data contains a NaN value, so I would expect the `rolling_sum` result to have a null at index position 0 and NaNs at index positions 3 and 4; indeed, this is what we see.
data_1 = pl.DataFrame(data_dict_1)
with pl.Config(tbl_rows=20):
    print(data_1.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))
shape: (9, 2)
┌─────┬─────────┐
│ x ┆ rolling │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞═════╪═════════╡
│ 1.0 ┆ null │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ NaN ┆ NaN │
│ 1.0 ┆ NaN │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
└─────┴─────────┘
But if a null (`None`) value is now added to the data, the behaviour deviates from what I would expect.
data_dict_2 = {"x":[1., 1., 1., np.nan, 1., 1., 1., 1., 1., None, 1., 1., 1.]}
Now I would expect a null at index position 0, NaNs at index positions 3 and 4, and nulls at index positions 9 and 10. Instead, the result I obtain shows some weird propagation of the NaNs.
data_2 = pl.DataFrame(data_dict_2)
with pl.Config(tbl_rows=20):
    print(data_2.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))
shape: (13, 2)
┌──────┬─────────┐
│ x ┆ rolling │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞══════╪═════════╡
│ 1.0 ┆ null │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ NaN ┆ NaN │
│ 1.0 ┆ NaN │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ NaN │
│ 1.0 ┆ NaN │
│ 1.0 ┆ NaN │
│ null ┆ null │
│ 1.0 ┆ null │
│ 1.0 ┆ NaN │
│ 1.0 ┆ NaN │
└──────┴─────────┘
Notice in particular that we only get one 'normal' value after the NaN windows (the 2.0 at index position 5); from then on the output is not what I expect (aside from the nulls).
For comparison, using the `rolling_map` function with Python's built-in `sum` I can obtain my desired result, i.e. what I would have expected `rolling_sum` to do.
with pl.Config(tbl_rows=20):
    print(data_2.with_columns(pl.col("x").rolling_map(sum, 2).alias("rolling")))
shape: (13, 2)
┌──────┬─────────┐
│ x ┆ rolling │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞══════╪═════════╡
│ 1.0 ┆ null │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ NaN ┆ NaN │
│ 1.0 ┆ NaN │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ null ┆ null │
│ 1.0 ┆ null │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
└──────┴─────────┘
In my testing of this example it does not seem to matter where the null value is placed in the data; the result is still what I deem to be unexpected.
Of course, in this small example a simple solution might be to replace all the NaNs with nulls, after which things behave as I expect (the NaNs in the output become nulls, but at least the actual values are there), or simply to use `rolling_map` with `sum`.
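That replacement is straightforward; a minimal sketch of what I mean, using Polars' `fill_nan` on the second dataset:
# Replace NaN with null before the rolling sum. Windows that touched a NaN
# now come out as null instead of NaN, but the other values are correct.
with pl.Config(tbl_rows=20):
    print(
        data_2.with_columns(
            pl.col("x").fill_nan(None).rolling_sum(2).alias("rolling")
        )
    )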
But I believe (correct me if I'm wrong) that using `rolling_map` is generally slower than using a built-in Polars function like `rolling_sum`.
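A rough, hypothetical micro-benchmark to illustrate the kind of gap I mean (the exact numbers will of course vary with machine and data size):
import timeit

import numpy as np
import polars as pl

df = pl.DataFrame({"x": np.random.rand(100_000)})

# Native rolling_sum runs in compiled code; rolling_map calls a Python
# function once per window, so it carries per-window overhead.
t_native = timeit.timeit(lambda: df.select(pl.col("x").rolling_sum(2)), number=1)
t_map = timeit.timeit(lambda: df.select(pl.col("x").rolling_map(sum, 2)), number=1)
print(f"rolling_sum: {t_native:.4f}s, rolling_map: {t_map:.4f}s")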
Also, the dataset on which I originally encountered this problem is more complicated than this example, and there the simple fix of replacing NaNs with nulls did not seem to help. I could not produce a simple dataset for which the problem persists after the NaN replacement, so instead I am hoping someone can shed some light on why this behaviour occurs on this simple dataset, so that I can debug my bigger analysis from there.
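In case it is useful for reproducing this, the way I have been locating the two kinds of missing values in a column is just the standard `is_nan`/`is_null` expressions, e.g.:
# Count NaNs and nulls in the column to see which kind of "missing" is present.
print(
    data_2.select(
        pl.col("x").is_nan().sum().alias("n_nan"),
        pl.col("x").is_null().sum().alias("n_null"),
    )
)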
There is a merged PR on main that fixes this for the next release:
data_2 = pl.DataFrame(data_dict_2)
with pl.Config(tbl_rows=20):
    print(data_2.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))
shape: (13, 2)
┌──────┬─────────┐
│ x ┆ rolling │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞══════╪═════════╡
│ 1.0 ┆ null │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ NaN ┆ NaN │
│ 1.0 ┆ NaN │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
│ null ┆ null │
│ 1.0 ┆ null │
│ 1.0 ┆ 2.0 │
│ 1.0 ┆ 2.0 │
└──────┴─────────┘
And yes, `rolling_map()` runs a Python UDF and materializes `Series` objects, so it has significant overhead. The docs contain a warning:
Computing custom functions is extremely slow. Use specialized rolling functions such as `Expr.rolling_sum()` if at all possible.