
How to reset cum_sum value based on condition in polars


I have the following DF, on which I am trying to construct a new column that will be a cum_sum.

import polars as pl
pl.Config(tbl_rows=20)
df = pl.LazyFrame({
    "a" : [0,1,2,3,4,5,6,7,8,9],
    "b" : [33,35,33,32,33,37,33,29,34,36],
    "c" : [33,35,35,35,35,37,37,37,37,37],
})
df = df.with_columns(pl.when(pl.col('b') == pl.col('c')).then(0).otherwise(1).alias('temp'))
df = df.with_columns(pl.when(pl.col('temp') == 0).then(0).otherwise(pl.col('temp').cum_sum()).alias('c_sum_n'))

Currently I am getting the following output:

>>> df.collect()
shape: (10, 5)
┌─────┬─────┬─────┬──────┬─────────┐
│ a   ┆ b   ┆ c   ┆ temp ┆ c_sum_n │
│ --- ┆ --- ┆ --- ┆ ---  ┆ ---     │
│ i64 ┆ i64 ┆ i64 ┆ i32  ┆ i32     │
╞═════╪═════╪═════╪══════╪═════════╡
│ 0   ┆ 33  ┆ 33  ┆ 0    ┆ 0       │
│ 1   ┆ 35  ┆ 35  ┆ 0    ┆ 0       │
│ 2   ┆ 33  ┆ 35  ┆ 1    ┆ 1       │
│ 3   ┆ 32  ┆ 35  ┆ 1    ┆ 2       │
│ 4   ┆ 33  ┆ 35  ┆ 1    ┆ 3       │
│ 5   ┆ 37  ┆ 37  ┆ 0    ┆ 0       │
│ 6   ┆ 33  ┆ 37  ┆ 1    ┆ 4       │
│ 7   ┆ 29  ┆ 37  ┆ 1    ┆ 5       │
│ 8   ┆ 34  ┆ 37  ┆ 1    ┆ 6       │
│ 9   ┆ 36  ┆ 37  ┆ 1    ┆ 7       │
└─────┴─────┴─────┴──────┴─────────┘

but the desired output I am looking for is:

>>> df.collect()
shape: (10, 5)
┌─────┬─────┬─────┬──────┬─────────┐
│ a   ┆ b   ┆ c   ┆ temp ┆ c_sum_n │
│ --- ┆ --- ┆ --- ┆ ---  ┆ ---     │
│ i64 ┆ i64 ┆ i64 ┆ i32  ┆ i32     │
╞═════╪═════╪═════╪══════╪═════════╡
│ 0   ┆ 33  ┆ 33  ┆ 0    ┆ 0       │
│ 1   ┆ 35  ┆ 35  ┆ 0    ┆ 0       │
│ 2   ┆ 33  ┆ 35  ┆ 1    ┆ 1       │
│ 3   ┆ 32  ┆ 35  ┆ 1    ┆ 2       │
│ 4   ┆ 33  ┆ 35  ┆ 1    ┆ 3       │
│ 5   ┆ 37  ┆ 37  ┆ 0    ┆ 0       │
│ 6   ┆ 33  ┆ 37  ┆ 1    ┆ 1       │
│ 7   ┆ 29  ┆ 37  ┆ 1    ┆ 2       │
│ 8   ┆ 34  ┆ 37  ┆ 1    ┆ 3       │
│ 9   ┆ 36  ┆ 37  ┆ 1    ┆ 4       │
└─────┴─────┴─────┴──────┴─────────┘

i.e. whenever col b == col c, I would like the cum_sum value to reset to zero and start all over again. Any help, either in Python Polars or Rust Polars, will be appreciated.
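To pin down the desired reset semantics independently of Polars, here is a plain-Python reference loop. The helper name reset_cum_sum is hypothetical (not a Polars function); it walks the rows one at a time, so it is only meant as a slow check to compare a vectorized solution against.

```python
def reset_cum_sum(b: list[int], c: list[int]) -> list[int]:
    """Hypothetical reference: running count of b != c rows, reset to 0 where b == c."""
    out = []
    running = 0
    for bi, ci in zip(b, c):
        if bi == ci:
            running = 0  # reset whenever b == c
        else:
            running += 1  # count non-matching rows since the last reset
        out.append(running)
    return out


b = [33, 35, 33, 32, 33, 37, 33, 29, 34, 36]
c = [33, 35, 35, 35, 35, 37, 37, 37, 37, 37]
print(reset_cum_sum(b, c))  # [0, 0, 1, 2, 3, 0, 1, 2, 3, 4]
```

The printed list matches the c_sum_n column of the desired output above.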


Solution

  • Here's one way to do this using over:

    We extract the predicate col("b") == col("c"), then take the cum_sum of ~predicate with over, partitioned by the cum_sum of the predicate.

    import polars as pl
    
    df = pl.DataFrame(
        {
            "a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            "b": [33, 35, 33, 32, 33, 37, 33, 29, 34, 36],
            "c": [33, 35, 35, 35, 35, 37, 37, 37, 37, 37],
        }
    )
    
    predicate = pl.col("b") == pl.col("c")
    
    df = df.with_columns((~predicate).cum_sum().over(predicate.cum_sum()).alias("c_sum_n"))
    
    print(df)
    

    Output:

    shape: (10, 4)
    ┌─────┬─────┬─────┬─────────┐
    │ a   ┆ b   ┆ c   ┆ c_sum_n │
    │ --- ┆ --- ┆ --- ┆ ---     │
    │ i64 ┆ i64 ┆ i64 ┆ u32     │
    ╞═════╪═════╪═════╪═════════╡
    │ 0   ┆ 33  ┆ 33  ┆ 0       │
    │ 1   ┆ 35  ┆ 35  ┆ 0       │
    │ 2   ┆ 33  ┆ 35  ┆ 1       │
    │ 3   ┆ 32  ┆ 35  ┆ 2       │
    │ 4   ┆ 33  ┆ 35  ┆ 3       │
    │ 5   ┆ 37  ┆ 37  ┆ 0       │
    │ 6   ┆ 33  ┆ 37  ┆ 1       │
    │ 7   ┆ 29  ┆ 37  ┆ 2       │
    │ 8   ┆ 34  ┆ 37  ┆ 3       │
    │ 9   ┆ 36  ┆ 37  ┆ 4       │
    └─────┴─────┴─────┴─────────┘
    

    How this works:

    Taking the cum_sum of the predicate increments a counter at every row where b == c, so each reset point starts a new group ID. That is exactly the partition we need to restart the cum_sum of ~predicate:

    print(df.with_columns(group=predicate.cum_sum()))
    

    Output:

    ┌─────┬─────┬─────┬───────┐
    │ a   ┆ b   ┆ c   ┆ group │
    │ --- ┆ --- ┆ --- ┆ ---   │
    │ i64 ┆ i64 ┆ i64 ┆ u32   │
    ╞═════╪═════╪═════╪═══════╡
    │ 0   ┆ 33  ┆ 33  ┆ 1     │
    │ 1   ┆ 35  ┆ 35  ┆ 2     │
    │ 2   ┆ 33  ┆ 35  ┆ 2     │
    │ 3   ┆ 32  ┆ 35  ┆ 2     │
    │ 4   ┆ 33  ┆ 35  ┆ 2     │
    │ 5   ┆ 37  ┆ 37  ┆ 3     │
    │ 6   ┆ 33  ┆ 37  ┆ 3     │
    │ 7   ┆ 29  ┆ 37  ┆ 3     │
    │ 8   ┆ 34  ┆ 37  ┆ 3     │
    │ 9   ┆ 36  ┆ 37  ┆ 3     │
    └─────┴─────┴─────┴───────┘