I want to split the dataframe according to the sublists that result from dividing a list into parts where the only value above a cutoff is the first.
e.g. Cutoff = 3
[4,2,3,5,2,1,6,7] => [4,2,3], [5,2,1], [6], [7]
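In plain Python, the rule could be sketched like this (split_at_cutoff is just a hypothetical helper to illustrate the rule, not part of my actual code):

def split_at_cutoff(values, cutoff):
    # start a new sublist whenever a value exceeds the cutoff
    out = []
    for v in values:
        if v > cutoff or not out:
            out.append([v])
        else:
            out[-1].append(v)
    return out

split_at_cutoff([4, 2, 3, 5, 2, 1, 6, 7], cutoff=3)
# => [[4, 2, 3], [5, 2, 1], [6], [7]]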
I still need to keep track of the other fields in the dataframe.
Given the df below, I should get the expected output that follows:
import polars as pl

data = {
    "uid": ["Alice", "Bob", "Charlie"],
    "time_deltas": [
        [4, 2, 3],
        [1, 1, 4, 8, 3],
        [1, 1, 7, 3, 2],
    ],
    "other_field": [
        ["x", "y", "z"],
        ["x", "y", "z", "x", "y"],
        ["x", "y", "z", "x", "y"],
    ],
}
df = pl.DataFrame(data)
cutoff = 3
# Split the time_deltas column into sublists where only the first value may exceed the cutoff. Ensure that the other_field column is also split accordingly.
# Expected Output
# +---------+-------------+-----------------+
# | uid     | time_deltas | other_field     |
# | ---     | ---         | ---             |
# | str     | list[i64]   | list[str]       |
# +---------+-------------+-----------------+
# | Alice   | [4, 2, 3]   | ["x", "y", "z"] |
# | Bob     | [1, 1]      | ["x", "y"]      |
# | Bob     | [4]         | ["z"]           |
# | Bob     | [8, 3]      | ["x", "y"]      |
# | Charlie | [1, 1]      | ["x", "y"]      |
# | Charlie | [7, 3, 2]   | ["z", "x", "y"] |
# +---------+-------------+-----------------+
If you explode/flatten the lists, you can take the .cum_sum() of the comparison .over() each "group" to assign a group id/index that identifies each sublist.
(df.explode("time_deltas", "other_field")
.with_columns(bool = pl.col("time_deltas") > cutoff)  # True where a value exceeds the cutoff
.with_columns(index = pl.col("bool").cum_sum().over("uid"))  # running count of Trues per uid = sublist id
)
shape: (13, 5)
┌─────────┬─────────────┬─────────────┬───────┬───────┐
│ uid ┆ time_deltas ┆ other_field ┆ bool ┆ index │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ str ┆ bool ┆ u32 │
╞═════════╪═════════════╪═════════════╪═══════╪═══════╡
│ Alice ┆ 4 ┆ x ┆ true ┆ 1 │
│ Alice ┆ 2 ┆ y ┆ false ┆ 1 │
│ Alice ┆ 3 ┆ z ┆ false ┆ 1 │
│ Bob ┆ 1 ┆ x ┆ false ┆ 0 │
│ Bob ┆ 1 ┆ y ┆ false ┆ 0 │
│ Bob ┆ 4 ┆ z ┆ true ┆ 1 │
│ Bob ┆ 8 ┆ x ┆ true ┆ 2 │
│ Bob ┆ 3 ┆ y ┆ false ┆ 2 │
│ Charlie ┆ 1 ┆ x ┆ false ┆ 0 │
│ Charlie ┆ 1 ┆ y ┆ false ┆ 0 │
│ Charlie ┆ 7 ┆ z ┆ true ┆ 1 │
│ Charlie ┆ 3 ┆ x ┆ false ┆ 1 │
│ Charlie ┆ 2 ┆ y ┆ false ┆ 1 │
└─────────┴─────────────┴─────────────┴───────┴───────┘
If uid is not unique, you could .with_row_index() first before exploding and use that in the .over() instead.
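A minimal sketch of that variant (the column name "row_nr" is an arbitrary choice, picked so it does not clash with the index alias):

(df.with_row_index("row_nr")  # adds a unique id per original row
.explode("time_deltas", "other_field")
.with_columns(index = (pl.col("time_deltas") > cutoff).cum_sum().over("row_nr"))
)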
You can then use .group_by() to reassemble the lists.
(df.explode("time_deltas", "other_field")
.with_columns(index = (pl.col("time_deltas") > cutoff).cum_sum().over("uid"))
.group_by("index", "uid", maintain_order=True)  # maintain_order=True preserves the original row order
.all()  # collect every remaining column back into lists
)
shape: (6, 4)
┌───────┬─────────┬─────────────┬─────────────────┐
│ index ┆ uid ┆ time_deltas ┆ other_field │
│ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ str ┆ list[i64] ┆ list[str] │
╞═══════╪═════════╪═════════════╪═════════════════╡
│ 1 ┆ Alice ┆ [4, 2, 3] ┆ ["x", "y", "z"] │
│ 0 ┆ Bob ┆ [1, 1] ┆ ["x", "y"] │
│ 1 ┆ Bob ┆ [4] ┆ ["z"] │
│ 2 ┆ Bob ┆ [8, 3] ┆ ["x", "y"] │
│ 0 ┆ Charlie ┆ [1, 1] ┆ ["x", "y"] │
│ 1 ┆ Charlie ┆ [7, 3, 2] ┆ ["z", "x", "y"] │
└───────┴─────────┴─────────────┴─────────────────┘
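The expected output doesn't contain the helper index column, so you can drop it at the end, e.g.:

(df.explode("time_deltas", "other_field")
.with_columns(index = (pl.col("time_deltas") > cutoff).cum_sum().over("uid"))
.group_by("index", "uid", maintain_order=True)
.all()
.drop("index")  # remove the helper column to match the expected output
)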