I'm trying to make all groups in a given data frame the same size. Under Starting point below, I show an example of a data frame that I wish to transform, and under Goal I show what I'm trying to achieve: group by the column group, make every group have a size of 4, and fill the 'missing' values with null. I hope it's clear.
I have tried several approaches but have not been able to figure this one out.
Starting point
import polars as pl

dfa = pl.DataFrame(data={'group': ['a', 'a', 'a', 'b', 'b', 'c'],
                         'value': ['a1', 'a2', 'a3', 'b1', 'b2', 'c1']})
┌───────┬───────┐
│ group ┆ value │
│ --- ┆ --- │
│ str ┆ str │
╞═══════╪═══════╡
│ a ┆ a1 │
│ a ┆ a2 │
│ a ┆ a3 │
│ b ┆ b1 │
│ b ┆ b2 │
│ c ┆ c1 │
└───────┴───────┘
Goal
>>> make_groups_uniform(dfa, group_by='group', group_size=4)
┌───────┬───────┐
│ group ┆ value │
│ --- ┆ --- │
│ str ┆ str │
╞═══════╪═══════╡
│ a ┆ a1 │
│ a ┆ a2 │
│ a ┆ a3 │
│ a ┆ null │
│ b ┆ b1 │
│ b ┆ b2 │
│ b ┆ null │
│ b ┆ null │
│ c ┆ c1 │
│ c ┆ null │
│ c ┆ null │
│ c ┆ null │
└───────┴───────┘
Package version
polars: 1.1.0
You could use pl.repeat() to generate the nulls and .append() them to each group (df below is the question's dfa):
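import polars as pl

group_size = 4
(df.group_by("group", maintain_order=True)
   # per group: the existing values followed by (group_size - group length) nulls
   .agg(pl.all().append(pl.repeat(None, group_size - pl.len().cast(int))))
   # turn the per-group lists back into one row per element
   .explode(pl.exclude("group"))
)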
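Or, instead of hard-coding group_size, you can use pl.len().over("group").max() to pad every group to the size of the largest group. Note that the helper group_size column is padded along with the other columns: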
(df.with_columns(group_size=pl.len().over("group").max())
.group_by("group", maintain_order=True)
.agg(pl.all().append(pl.repeat(None, pl.col("group_size") - pl.len().cast(int))))
.explode(pl.exclude("group"))
)
shape: (9, 3)
┌───────┬───────┬────────────┐
│ group ┆ value ┆ group_size │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ u32 │
╞═══════╪═══════╪════════════╡
│ a ┆ a1 ┆ 3 │
│ a ┆ a2 ┆ 3 │
│ a ┆ a3 ┆ 3 │
│ b ┆ b1 ┆ 3 │
│ b ┆ b2 ┆ 3 │
│ b ┆ null ┆ null │
│ c ┆ c1 ┆ 3 │
│ c ┆ null ┆ null │
│ c ┆ null ┆ null │
└───────┴───────┴────────────┘
pl.len() currently returns an unsigned integer (u32), although that may change in the future. We cast it to a signed integer with .cast(int) so that subtracting a larger length goes negative instead of wrapping around:
df.group_by("group").agg(int = 2 - pl.len().cast(int), uint32 = 2 - pl.len())
shape: (3, 3)
┌───────┬─────┬────────────┐
│ group ┆ int ┆ uint32 │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ u32 │
╞═══════╪═════╪════════════╡
│ a ┆ -1 ┆ 4294967295 │ # <- pl.repeat(None, 4294967295) would be bad.
│ b ┆ 0 ┆ 0 │
│ c ┆ 1 ┆ 1 │
└───────┴─────┴────────────┘
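The 4294967295 in the uint32 column is ordinary unsigned wraparound: a u32 result is taken modulo 2**32, so -1 becomes 2**32 - 1. The plain-Python emulation below reproduces both columns of the table; it only models the arithmetic and is not how polars implements it internally.

```python
U32_MODULUS = 2 ** 32  # number of distinct values a u32 can hold


def u32_sub(a: int, b: int) -> int:
    """Emulate unsigned 32-bit subtraction: the result wraps modulo 2**32."""
    return (a - b) % U32_MODULUS


for group_len in (3, 2, 1):           # lengths of groups a, b, c
    signed = 2 - group_len            # what the Int64 column shows
    unsigned = u32_sub(2, group_len)  # what the uint32 column shows
    print(group_len, signed, unsigned)
```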
.clip() can also be used to enforce lower/upper bounds, for example to keep the number of padding nulls from dropping below zero.