
Polars pl.when().then().otherwise() in conjunction with first row of group_by object


I have a pl.DataFrame with the columns level_0, symbol, signal, and trade. The trade column indicates whether to buy or sell the respective symbol ("A" or "B"). It is computed from signal over(["level_0", "symbol"]).

import polars as pl

df = pl.DataFrame(
    {
        "level_0": [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        "symbol": [ "A", "A", "A", "A", "B", "B", "B", "B", "A", "A", "A", "A", "B", "B", "B", "B", ],
        "signal": [1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0],
    }
).with_columns(
    pl.col("signal")
    .diff()
    .replace(old=0, new=None)
    .over(["level_0", "symbol"])
    .alias("trade")
)

shape: (16, 4)
┌─────────┬────────┬────────┬───────┐
│ level_0 ┆ symbol ┆ signal ┆ trade │
│ ---     ┆ ---    ┆ ---    ┆ ---   │
│ i64     ┆ str    ┆ i64    ┆ i64   │
╞═════════╪════════╪════════╪═══════╡
│ 0       ┆ A      ┆ 1      ┆ null  │
│ 0       ┆ A      ┆ 0      ┆ -1    │
│ 0       ┆ A      ┆ 1      ┆ 1     │
│ 0       ┆ A      ┆ 1      ┆ null  │
│ 0       ┆ B      ┆ 0      ┆ null  │
│ 0       ┆ B      ┆ 1      ┆ 1     │
│ 0       ┆ B      ┆ 1      ┆ null  │
│ 0       ┆ B      ┆ 0      ┆ -1    │
│ 1       ┆ A      ┆ 0      ┆ null  │
│ 1       ┆ A      ┆ 0      ┆ null  │
│ 1       ┆ A      ┆ 0      ┆ null  │
│ 1       ┆ A      ┆ 1      ┆ 1     │
│ 1       ┆ B      ┆ 1      ┆ null  │
│ 1       ┆ B      ┆ 1      ┆ null  │
│ 1       ┆ B      ┆ 0      ┆ -1    │
│ 1       ┆ B      ┆ 0      ┆ null  │
└─────────┴────────┴────────┴───────┘

So far, so good. The only problem is that the first row of each group (["level_0", "symbol"]) isn't correct. I would like to replace the null in the first row of each group in the trade column whenever the signal value in that row is different from zero. If the signal in that row is zero, the null should stay unchanged.

Here's what I'm looking for:

shape: (16, 4)
┌─────────┬────────┬────────┬───────┐
│ level_0 ┆ symbol ┆ signal ┆ trade │
│ ---     ┆ ---    ┆ ---    ┆ ---   │
│ i64     ┆ str    ┆ i64    ┆ i64   │
╞═════════╪════════╪════════╪═══════╡
│ 0       ┆ A      ┆ 1      ┆ 1     │ <-- copied value from signal column
│ 0       ┆ A      ┆ 0      ┆ -1    │
│ 0       ┆ A      ┆ 1      ┆ 1     │
│ 0       ┆ A      ┆ 1      ┆ null  │
│ 0       ┆ B      ┆ 0      ┆ null  │ <-- value stays unchanged
│ 0       ┆ B      ┆ 1      ┆ 1     │
│ 0       ┆ B      ┆ 1      ┆ null  │
│ 0       ┆ B      ┆ 0      ┆ -1    │
│ 1       ┆ A      ┆ 0      ┆ null  │ <-- value stays unchanged
│ 1       ┆ A      ┆ 0      ┆ null  │
│ 1       ┆ A      ┆ 0      ┆ null  │
│ 1       ┆ A      ┆ 1      ┆ 1     │
│ 1       ┆ B      ┆ 1      ┆ 1     │ <-- copied value from signal column
│ 1       ┆ B      ┆ 1      ┆ null  │
│ 1       ┆ B      ┆ 0      ┆ -1    │
│ 1       ┆ B      ┆ 0      ┆ null  │
└─────────┴────────┴────────┴───────┘

Solution

  • You can create a "group key" from multiple columns/expressions using pl.struct().

    .is_first_distinct() can then be used to identify the first row of each "group".

    with pl.Config(tbl_rows=16):
        print(
            df.with_columns(
                # multiple predicates passed to pl.when() are combined with AND
                pl.when(pl.struct("level_0", "symbol").is_first_distinct(), pl.col.signal == 1)
                .then(1)
                .otherwise(pl.col.trade)
                .alias("trade")
            )
        )
    
    shape: (16, 4)
    ┌─────────┬────────┬────────┬───────┐
    │ level_0 ┆ symbol ┆ signal ┆ trade │
    │ ---     ┆ ---    ┆ ---    ┆ ---   │
    │ i64     ┆ str    ┆ i64    ┆ i64   │
    ╞═════════╪════════╪════════╪═══════╡
    │ 0       ┆ A      ┆ 1      ┆ 1     │
    │ 0       ┆ A      ┆ 0      ┆ -1    │
    │ 0       ┆ A      ┆ 1      ┆ 1     │
    │ 0       ┆ A      ┆ 1      ┆ null  │
    │ 0       ┆ B      ┆ 0      ┆ null  │
    │ 0       ┆ B      ┆ 1      ┆ 1     │
    │ 0       ┆ B      ┆ 1      ┆ null  │
    │ 0       ┆ B      ┆ 0      ┆ -1    │
    │ 1       ┆ A      ┆ 0      ┆ null  │
    │ 1       ┆ A      ┆ 0      ┆ null  │
    │ 1       ┆ A      ┆ 0      ┆ null  │
    │ 1       ┆ A      ┆ 1      ┆ 1     │
    │ 1       ┆ B      ┆ 1      ┆ 1     │
    │ 1       ┆ B      ┆ 1      ┆ null  │
    │ 1       ┆ B      ┆ 0      ┆ -1    │
    │ 1       ┆ B      ┆ 0      ┆ null  │
    └─────────┴────────┴────────┴───────┘
    

    Update per follow-up question:

    "check if signal in the first row of each group is different from zero, but I then replace the trade value of the following row"

    You can .shift() over groups to move forward/backward by N rows.

    df.with_columns(
       pl.all_horizontal(
          pl.struct("level_0", "symbol").is_first_distinct(),
          pl.col("signal") == 1
       )
    #   .shift()
    #   .over("level_0", "symbol")
       .alias("change me")
    )
    
    shape: (16, 5)
    ┌─────────┬────────┬────────┬───────┬───────────┐
    │ level_0 ┆ symbol ┆ signal ┆ trade ┆ change me │
    │ ---     ┆ ---    ┆ ---    ┆ ---   ┆ ---       │
    │ i64     ┆ str    ┆ i64    ┆ i64   ┆ bool      │
    ╞═════════╪════════╪════════╪═══════╪═══════════╡
    │ 0       ┆ A      ┆ 1      ┆ null  ┆ true      │
    │ 0       ┆ A      ┆ 0      ┆ -1    ┆ false     │ # <- shift to here
    │ 0       ┆ A      ┆ 1      ┆ 1     ┆ false     │
    │ 0       ┆ A      ┆ 1      ┆ null  ┆ false     │
    │ 0       ┆ B      ┆ 0      ┆ null  ┆ false     │
    │ 0       ┆ B      ┆ 1      ┆ 1     ┆ false     │ # <- shift to here
    │ 0       ┆ B      ┆ 1      ┆ null  ┆ false     │
    │ 0       ┆ B      ┆ 0      ┆ -1    ┆ false     │
    │ 1       ┆ A      ┆ 0      ┆ null  ┆ false     │
    │ 1       ┆ A      ┆ 0      ┆ null  ┆ false     │ # <- shift to here
    │ 1       ┆ A      ┆ 0      ┆ null  ┆ false     │
    │ 1       ┆ A      ┆ 1      ┆ 1     ┆ false     │
    │ 1       ┆ B      ┆ 1      ┆ null  ┆ true      │
    │ 1       ┆ B      ┆ 1      ┆ null  ┆ false     │ # <- shift to here
    │ 1       ┆ B      ┆ 0      ┆ -1    ┆ false     │
    │ 1       ┆ B      ┆ 0      ┆ null  ┆ false     │
    └─────────┴────────┴────────┴───────┴───────────┘
    

    Used within when/then:

    df.with_columns(
       pl.when(
          pl.all_horizontal(
             pl.struct("level_0", "symbol").is_first_distinct(),
             pl.col("signal") == 1
          )
          .shift()
          .over("level_0", "symbol")
       )
       .then(1)
       .otherwise(pl.col("trade"))
       .alias("new_trade")
    )
    
    shape: (16, 5)
    ┌─────────┬────────┬────────┬───────┬───────────┐
    │ level_0 ┆ symbol ┆ signal ┆ trade ┆ new_trade │
    │ ---     ┆ ---    ┆ ---    ┆ ---   ┆ ---       │
    │ i64     ┆ str    ┆ i64    ┆ i64   ┆ i64       │
    ╞═════════╪════════╪════════╪═══════╪═══════════╡
    │ 0       ┆ A      ┆ 1      ┆ null  ┆ null      │
    │ 0       ┆ A      ┆ 0      ┆ -1    ┆ 1         │ # CHANGE
    │ 0       ┆ A      ┆ 1      ┆ 1     ┆ 1         │
    │ 0       ┆ A      ┆ 1      ┆ null  ┆ null      │
    │ 0       ┆ B      ┆ 0      ┆ null  ┆ null      │
    │ 0       ┆ B      ┆ 1      ┆ 1     ┆ 1         │
    │ 0       ┆ B      ┆ 1      ┆ null  ┆ null      │
    │ 0       ┆ B      ┆ 0      ┆ -1    ┆ -1        │
    │ 1       ┆ A      ┆ 0      ┆ null  ┆ null      │
    │ 1       ┆ A      ┆ 0      ┆ null  ┆ null      │
    │ 1       ┆ A      ┆ 0      ┆ null  ┆ null      │
    │ 1       ┆ A      ┆ 1      ┆ 1     ┆ 1         │
    │ 1       ┆ B      ┆ 1      ┆ null  ┆ null      │
    │ 1       ┆ B      ┆ 1      ┆ null  ┆ 1         │ # CHANGE
    │ 1       ┆ B      ┆ 0      ┆ -1    ┆ -1        │
    │ 1       ┆ B      ┆ 0      ┆ null  ┆ null      │
    └─────────┴────────┴────────┴───────┴───────────┘