
Unexpected behaviour of some Polars rolling functions when NaNs and nulls appear together


I recently came across behaviour in some of the Polars rolling functions that I don't understand. The problem only seems to present itself when the data contains both a NaN (np.nan) and a null (None).

I am on Polars 1.31.0, NumPy 2.3.1 and Python 3.12.11.
Here are some examples; note that I will be using the rolling_sum function with a window size of 2 throughout.

First, a non-problematic application with the expected behaviour:
import numpy as np
import polars as pl

data_dict_1 = {"x": [1., 1., 1., np.nan, 1., 1., 1., 1., 1.]}
This data contains a NaN value, so I would expect the rolling_sum output to have a null at index position 0 and NaNs at index positions 3 and 4; indeed, this is what we see.

data_1 = pl.DataFrame(data_dict_1)
with pl.Config(tbl_rows=20):
    print(data_1.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))
shape: (9, 2)
┌─────┬─────────┐
│ x   ┆ rolling │
│ --- ┆ ---     │
│ f64 ┆ f64     │
╞═════╪═════════╡
│ 1.0 ┆ null    │
│ 1.0 ┆ 2.0     │
│ 1.0 ┆ 2.0     │
│ NaN ┆ NaN     │
│ 1.0 ┆ NaN     │
│ 1.0 ┆ 2.0     │
│ 1.0 ┆ 2.0     │
│ 1.0 ┆ 2.0     │
│ 1.0 ┆ 2.0     │
└─────┴─────────┘

But if a null (None) value is added to the data, the behaviour deviates from what I would expect.
data_dict_2 = {"x":[1., 1., 1., np.nan, 1., 1., 1., 1., 1., None, 1., 1., 1.]}
Now I would expect a null at index position 0, NaNs at index positions 3 and 4, and nulls at index positions 9 and 10. Instead, the result shows some strange propagation of the NaNs.

data_2 = pl.DataFrame(data_dict_2)
with pl.Config(tbl_rows=20):
    print(data_2.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))
shape: (13, 2)
┌──────┬─────────┐
│ x    ┆ rolling │
│ ---  ┆ ---     │
│ f64  ┆ f64     │
╞══════╪═════════╡
│ 1.0  ┆ null    │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ 2.0     │
│ NaN  ┆ NaN     │
│ 1.0  ┆ NaN     │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ NaN     │
│ 1.0  ┆ NaN     │
│ 1.0  ┆ NaN     │
│ null ┆ null    │
│ 1.0  ┆ null    │
│ 1.0  ┆ NaN     │
│ 1.0  ┆ NaN     │
└──────┴─────────┘

Notice in particular that we only get one 'normal' value after the NaN (the 2.0 at index position 5); from then onwards the output is not what I expect (aside from the nulls). For comparison, using the rolling_map function with Python's built-in sum gives the result I wanted, and what I would have expected rolling_sum to produce.

with pl.Config(tbl_rows=20):
    print(data_2.with_columns(pl.col("x").rolling_map(sum, 2).alias("rolling")))
shape: (13, 2)
┌──────┬─────────┐
│ x    ┆ rolling │
│ ---  ┆ ---     │
│ f64  ┆ f64     │
╞══════╪═════════╡
│ 1.0  ┆ null    │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ 2.0     │
│ NaN  ┆ NaN     │
│ 1.0  ┆ NaN     │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ 2.0     │
│ null ┆ null    │
│ 1.0  ┆ null    │
│ 1.0  ┆ 2.0     │
│ 1.0  ┆ 2.0     │
└──────┴─────────┘

In my testing it does not seem to matter where the null value is placed in the data; the result is still what I consider unexpected.
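
For intuition about how such propagation can arise at all, here is a purely illustrative sketch of an incremental rolling sum (an assumed implementation strategy, not the actual Polars code): once a NaN enters a running accumulator, subtracting it back out cannot restore the sum, because NaN - NaN is still NaN. Polars evidently does handle the NaN leaving the window when no nulls are present (the first example recovers to 2.0 after index 4), so whatever goes wrong presumably sits in the null-aware code path.

import math

def naive_rolling_sum(values, window):
    # Illustrative only: an assumed incremental strategy, not Polars internals.
    out = [None] * (window - 1)
    acc = sum(values[:window])
    out.append(acc)
    for i in range(window, len(values)):
        # acc is poisoned as soon as a NaN enters the window; even when the
        # NaN later departs, acc - NaN is NaN, so the accumulator never recovers.
        acc += values[i] - values[i - window]
        out.append(acc)
    return out

print(naive_rolling_sum([1., 1., 1., math.nan, 1., 1., 1.], 2))
# [None, 2.0, 2.0, nan, nan, nan, nan]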
Of course, in this small example a simple solution might be to replace all the NaNs with nulls, after which things behave as I expect (the NaNs in the output would become nulls, but at least the actual values would be there), or to just use rolling_map with sum.
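For reference, the NaN-to-null replacement can be written with Polars' fill_nan (passing None converts the NaNs to nulls):

data_2_filled = data_2.with_columns(pl.col("x").fill_nan(None))
with pl.Config(tbl_rows=20):
    # With the NaNs gone, rolling_sum treats those positions as missing values
    print(data_2_filled.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))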
But I believe (correct me if I'm wrong) that using rolling_map is generally slower than a built-in Polars function like rolling_sum.
Also, the dataset on which I originally encountered this problem is more complicated than this example, and the simple fix of replacing NaNs with nulls did not seem to help there. I could not produce a small dataset for which the problem persists after the NaN replacement, so I am hoping someone can shed some light on why this behaviour occurs on this simple dataset; from there I can try to debug my larger analysis.


Solution

  • There is a merged PR on main that fixes this for the next release:

    data_2 = pl.DataFrame(data_dict_2)
    
    with pl.Config(tbl_rows=20):
        print(data_2.with_columns(pl.col("x").rolling_sum(2).alias("rolling")))
    
    ┌──────┬─────────┐
    │ x    ┆ rolling │
    │ ---  ┆ ---     │
    │ f64  ┆ f64     │
    ╞══════╪═════════╡
    │ 1.0  ┆ null    │
    │ 1.0  ┆ 2.0     │
    │ 1.0  ┆ 2.0     │
    │ NaN  ┆ NaN     │
    │ 1.0  ┆ NaN     │
    │ 1.0  ┆ 2.0     │
    │ 1.0  ┆ 2.0     │
    │ 1.0  ┆ 2.0     │
    │ 1.0  ┆ 2.0     │
    │ null ┆ null    │
    │ 1.0  ┆ null    │
    │ 1.0  ┆ 2.0     │
    │ 1.0  ┆ 2.0     │
    └──────┴─────────┘
    

    And yes, rolling_map() runs a Python UDF and materializes Series objects, so it has significant overhead.
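
    As a rough illustration of that overhead (a hypothetical micro-benchmark; the DataFrame size, window and timings are arbitrary and will vary by machine):

    import time

    import numpy as np
    import polars as pl

    df = pl.DataFrame({"x": np.random.rand(100_000)})

    t0 = time.perf_counter()
    df.select(pl.col("x").rolling_sum(10))       # native rolling kernel
    t1 = time.perf_counter()
    df.select(pl.col("x").rolling_map(sum, 10))  # Python UDF per window
    t2 = time.perf_counter()

    print(f"rolling_sum: {t1 - t0:.4f}s  rolling_map: {t2 - t1:.4f}s")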

    The docs contain a warning:

    Computing custom functions is extremely slow. Use specialized rolling functions such as Expr.rolling_sum() if at all possible.