I'm trying to aggregate some rows in my dataframe with a list[str]
column. For each id
I need the intersection of all the lists in the group. Not sure if I'm just overthinking it, but I can't come up with a working solution. Any help please?
# Sample data: two list[str] rows per id.
df = pl.DataFrame(
    {
        "id": [1, 1, 2, 2, 3, 3],
        "values": [
            ["A", "B"], ["B", "C"],
            ["A", "B"], ["B", "C"],
            ["A", "B"], ["B", "C"],
        ],
    }
)
Expected output
shape: (3, 2)
┌─────┬───────────┐
│ id  ┆ values    │
│ --- ┆ --- │
│ i64 ┆ list[str] │
╞═════╪═══════════╡
│ 1 ┆ ["B"] │
│ 2 ┆ ["B"] │
│ 3 ┆ ["B"] │
└─────┴───────────┘
I've tried some stuff without success
# Attempt 1: reduce over the aggregated column — does not intersect anything.
df.group_by("id").agg(
    pl.reduce(
        function=lambda left, right: left.list.set_intersection(right),
        exprs=pl.col("values"),
    )
)
# shape: (3, 2)
# ┌─────┬──────────────────────────┐
# │ id  ┆ values                   │
# │ --- ┆ ---                      │
# │ i64 ┆ list[list[str]]          │
# ╞═════╪══════════════════════════╡
# │ 1   ┆ [["A", "B"], ["B", "C"]] │
# │ 3   ┆ [["A", "B"], ["B", "C"]] │
# │ 2   ┆ [["A", "B"], ["B", "C"]] │
# └─────┴──────────────────────────┘
Another one
# Attempt 2: explode first — just concatenates the group's values instead.
df.group_by("id").agg(
    pl.reduce(
        function=lambda left, right: left.list.set_intersection(right),
        exprs=pl.col("values").explode(),
    )
)
# shape: (3, 2)
# ┌─────┬──────────────────────┐
# │ id  ┆ values               │
# │ --- ┆ ---                  │
# │ i64 ┆ list[str]            │
# ╞═════╪══════════════════════╡
# │ 3   ┆ ["A", "B", "B", "C"] │
# │ 1   ┆ ["A", "B", "B", "C"] │
# │ 2   ┆ ["A", "B", "B", "C"] │
# └─────┴──────────────────────┘
I'm not sure if this is as simple as it may first seem.
You could get rid of the lists and use "regular" Polars functionality.
One way to check whether a value appears in every row of an id
group is to count the number of unique (distinct) row numbers within each (id, values)
group and compare it to the group size.
# Inspect the intermediate state: explode the lists, then count how many
# distinct source rows each (id, value) pair came from.
(
    df
    .with_columns(pl.len().over("id").alias("group_len"))
    .with_row_index()
    .explode("values")
    .with_columns(
        pl.col("index").n_unique().over("id", "values").alias("n_unique")
    )
)
shape: (12, 5)
┌────────┬─────┬────────┬───────────┬──────────┐
│ index ┆ id ┆ values ┆ group_len ┆ n_unique │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ i64 ┆ str ┆ u32 ┆ u32 │
╞════════╪═════╪════════╪═══════════╪══════════╡
│ 0 ┆ 1 ┆ A ┆ 2 ┆ 1 │
│ 0 ┆ 1 ┆ B ┆ 2 ┆ 2 │ # index = [0, 1]
│ 1 ┆ 1 ┆ B ┆ 2 ┆ 2 │
│ 1 ┆ 1 ┆ C ┆ 2 ┆ 1 │
│ 2 ┆ 2 ┆ A ┆ 2 ┆ 1 │
│ 2 ┆ 2 ┆ B ┆ 2 ┆ 2 │ # index = [2, 3]
│ 3 ┆ 2 ┆ B ┆ 2 ┆ 2 │
│ 3 ┆ 2 ┆ C ┆ 2 ┆ 1 │
│ 4 ┆ 3 ┆ A ┆ 2 ┆ 1 │
│ 4 ┆ 3 ┆ B ┆ 2 ┆ 2 │ # index = [4, 5]
│ 5 ┆ 3 ┆ B ┆ 2 ┆ 2 │
│ 5 ┆ 3 ┆ C ┆ 2 ┆ 1 │
└────────┴─────┴────────┴───────────┴──────────┘
You can filter those, and rebuild the lists with .group_by()
# Keep only values present in every row of the group (distinct source rows
# == group size), then rebuild the lists.
(
    df
    .with_columns(group_len=pl.len().over("id"))
    .with_row_index()
    .explode("values")
    .filter(
        pl.col("index").n_unique().over("id", "values") == pl.col("group_len")
    )
    .group_by("id", maintain_order=True)
    .agg(pl.col("values").unique())
)
shape: (3, 2)
┌─────┬───────────┐
│ id  ┆ values    │
│ --- ┆ --- │
│ i64 ┆ list[str] │
╞═════╪═══════════╡
│ 1 ┆ ["B"] │
│ 2 ┆ ["B"] │
│ 3 ┆ ["B"] │
└─────┴───────────┘