
Split dataframe according to sub-lists by cutoff value

I want to split the DataFrame according to the sublists obtained by dividing each list into parts in which only the first value may be above a cutoff.

For example, with cutoff = 3:

[4,2,3,5,2,1,6,7] => [4,2,3], [5,2,1], [6], [7]
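The splitting rule itself can be sketched on a plain Python list. split_at_cutoff is a hypothetical helper used only to illustrate the rule; it is not part of Polars and is not how the DataFrame version has to be done.

```python
def split_at_cutoff(values, cutoff):
    """Hypothetical helper: start a new sublist at every value above cutoff.

    Within each resulting sublist, only the first element can exceed cutoff.
    """
    chunks = []
    for v in values:
        # A value above the cutoff always opens a new sublist; the very first
        # value also opens one, so a leading run of small values is kept.
        if v > cutoff or not chunks:
            chunks.append([v])
        else:
            chunks[-1].append(v)
    return chunks


print(split_at_cutoff([4, 2, 3, 5, 2, 1, 6, 7], 3))
# [[4, 2, 3], [5, 2, 1], [6], [7]]
```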

I still need to keep track of the other fields in the dataframe.

Given this DataFrame, I should get the expected output shown further down:

import polars as pl

data = {
    "uid": ["Alice", "Bob", "Charlie"],
    "time_deltas": [
        [4, 2, 3],
        [1, 1, 4, 8, 3],
        [1, 1, 7, 3, 2],
    ],
    "other_field": [["x", "y", "z"], ["x", "y", "z", "x", "y"], ["x", "y", "z", "x", "y"]],
}

df = pl.DataFrame(data)
cutoff = 3

# Split each time_deltas list into sublists in which only the first value may exceed
# the cutoff (every value > cutoff starts a new sublist). Split other_field at the same positions.

# Expected output
# +---------+-------------+-----------------+
# | uid     | time_deltas | other_field     |
# | ---     | ---         | ---             |
# | str     | list[i64]   | list[str]       |
# +---------+-------------+-----------------+
# | Alice   | [4, 2, 3]   | ["x", "y", "z"] |
# | Bob     | [1, 1]      | ["x", "y"]      |
# | Bob     | [4]         | ["z"]           |
# | Bob     | [8, 3]      | ["x", "y"]      |
# | Charlie | [1, 1]      | ["x", "y"]      |
# | Charlie | [7, 3, 2]   | ["z", "x", "y"] |
# +---------+-------------+-----------------+


  • If you explode (flatten) the lists, you can take the .cum_sum() of the comparison (value > cutoff) .over() each "group" to assign an id/index that identifies each sublist.

    (df.explode("time_deltas", "other_field")
       .with_columns(bool = pl.col("time_deltas") > cutoff)
       .with_columns(index = pl.col("bool").cum_sum().over("uid")))
    shape: (13, 5)
    ┌─────────┬─────────────┬─────────────┬───────┬───────┐
    │ uid     ┆ time_deltas ┆ other_field ┆ bool  ┆ index │
    │ ---     ┆ ---         ┆ ---         ┆ ---   ┆ ---   │
    │ str     ┆ i64         ┆ str         ┆ bool  ┆ u32   │
    ╞═════════╪═════════════╪═════════════╪═══════╪═══════╡
    │ Alice   ┆ 4           ┆ x           ┆ true  ┆ 1     │
    │ Alice   ┆ 2           ┆ y           ┆ false ┆ 1     │
    │ Alice   ┆ 3           ┆ z           ┆ false ┆ 1     │
    │ Bob     ┆ 1           ┆ x           ┆ false ┆ 0     │
    │ Bob     ┆ 1           ┆ y           ┆ false ┆ 0     │
    │ Bob     ┆ 4           ┆ z           ┆ true  ┆ 1     │
    │ Bob     ┆ 8           ┆ x           ┆ true  ┆ 2     │
    │ Bob     ┆ 3           ┆ y           ┆ false ┆ 2     │
    │ Charlie ┆ 1           ┆ x           ┆ false ┆ 0     │
    │ Charlie ┆ 1           ┆ y           ┆ false ┆ 0     │
    │ Charlie ┆ 7           ┆ z           ┆ true  ┆ 1     │
    │ Charlie ┆ 3           ┆ x           ┆ false ┆ 1     │
    │ Charlie ┆ 2           ┆ y           ┆ false ┆ 1     │
    └─────────┴─────────────┴─────────────┴───────┴───────┘

    If uid is not unique, you can add a row index with .with_row_index() before exploding and use that in the .over() instead.

    You can then use .group_by() to reassemble the lists.

    (df.explode("time_deltas", "other_field")
       .with_columns(index = (pl.col("time_deltas") > cutoff).cum_sum().over("uid"))
       .group_by("index", "uid", maintain_order=True)
       .agg("time_deltas", "other_field"))
    shape: (6, 4)
    ┌───────┬─────────┬─────────────┬─────────────────┐
    │ index ┆ uid     ┆ time_deltas ┆ other_field     │
    │ ---   ┆ ---     ┆ ---         ┆ ---             │
    │ u32   ┆ str     ┆ list[i64]   ┆ list[str]       │
    ╞═══════╪═════════╪═════════════╪═════════════════╡
    │ 1     ┆ Alice   ┆ [4, 2, 3]   ┆ ["x", "y", "z"] │
    │ 0     ┆ Bob     ┆ [1, 1]      ┆ ["x", "y"]      │
    │ 1     ┆ Bob     ┆ [4]         ┆ ["z"]           │
    │ 2     ┆ Bob     ┆ [8, 3]      ┆ ["x", "y"]      │
    │ 0     ┆ Charlie ┆ [1, 1]      ┆ ["x", "y"]      │
    │ 1     ┆ Charlie ┆ [7, 3, 2]   ┆ ["z", "x", "y"] │
    └───────┴─────────┴─────────────┴─────────────────┘