
Polars: make all groups the same size


Question

I'm trying to make all groups in a given DataFrame the same size. Under "Starting point" below, I show an example of a DataFrame that I wish to transform, and under "Goal" I show what I'm trying to achieve. I want to group by the column group, make every group have a size of 4, and fill the 'missing' values with null. I hope that's clear.

I have tried several approaches but have not been able to figure this one out.

Starting point

import polars as pl

dfa = pl.DataFrame(data={'group': ['a', 'a', 'a', 'b', 'b', 'c'],
                         'value': ['a1', 'a2', 'a3', 'b1', 'b2', 'c1']})
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ str   │
╞═══════╪═══════╡
│ a     ┆ a1    │
│ a     ┆ a2    │
│ a     ┆ a3    │
│ b     ┆ b1    │
│ b     ┆ b2    │
│ c     ┆ c1    │
└───────┴───────┘

Goal

>>> make_groups_uniform(dfa, group_by='group', group_size=4)
┌───────┬───────┐
│ group ┆ value │
│ ---   ┆ ---   │
│ str   ┆ str   │
╞═══════╪═══════╡
│ a     ┆ a1    │
│ a     ┆ a2    │
│ a     ┆ a3    │
│ a     ┆ null  │
│ b     ┆ b1    │
│ b     ┆ b2    │
│ b     ┆ null  │
│ b     ┆ null  │
│ c     ┆ c1    │
│ c     ┆ null  │
│ c     ┆ null  │
│ c     ┆ null  │
└───────┴───────┘
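To pin down the intended semantics, here is a small plain-Python model of the Goal (no Polars involved). The helper name pad_groups_reference is hypothetical, and one behaviour is my own assumption since the question doesn't cover it: a group that already has group_size rows or more is left unchanged rather than truncated. Groups come out in order of first appearance, matching the table above.

```python
def pad_groups_reference(rows, group_size):
    """Plain-Python model of the desired output (hypothetical helper, not Polars).

    rows: list of (group, value) tuples. Groups are emitted in order of first
    appearance and padded with (group, None) up to group_size rows.
    Assumption: groups already at or above group_size are left unchanged.
    """
    groups = {}  # dicts preserve insertion order (Python 3.7+)
    for group, value in rows:
        groups.setdefault(group, []).append(value)

    out = []
    for group, values in groups.items():
        padded = values + [None] * max(group_size - len(values), 0)
        out.extend((group, v) for v in padded)
    return out


rows = list(zip(['a', 'a', 'a', 'b', 'b', 'c'],
                ['a1', 'a2', 'a3', 'b1', 'b2', 'c1']))
for row in pad_groups_reference(rows, group_size=4):
    print(row)
```

Any Polars solution can be checked against this by comparing it with list(zip(result["group"], result["value"])).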

Package version

polars: 1.1.0


Solution

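  • You could use pl.repeat() to generate the required number of nulls for each group, .append() them to the group's values inside .agg(), and then .explode() the resulting lists back into rows.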

    group_size = 4
    
    (dfa.group_by("group", maintain_order=True)
       .agg(pl.all().append(pl.repeat(None, group_size - pl.len().cast(int))))
       .explode(pl.exclude("group"))
    )
    

    Or, if you want to pad every group to the size of the largest group instead of a fixed size, you can use pl.len().over("group").max() to calculate that size.

    (dfa.with_columns(group_size=pl.len().over("group").max())
       .group_by("group", maintain_order=True)
       .agg(pl.all().append(pl.repeat(None, pl.col("group_size").first() - pl.len().cast(int))))
       .explode(pl.exclude("group"))
    )
    
    shape: (9, 3)
    ┌───────┬───────┬────────────┐
    │ group ┆ value ┆ group_size │
    │ ---   ┆ ---   ┆ ---        │
    │ str   ┆ str   ┆ u32        │
    ╞═══════╪═══════╪════════════╡
    │ a     ┆ a1    ┆ 3          │
    │ a     ┆ a2    ┆ 3          │
    │ a     ┆ a3    ┆ 3          │
    │ b     ┆ b1    ┆ 3          │
    │ b     ┆ b2    ┆ 3          │
    │ b     ┆ null  ┆ null       │
    │ c     ┆ c1    ┆ 3          │
    │ c     ┆ null  ┆ null       │
    │ c     ┆ null  ┆ null       │
    └───────┴───────┴────────────┘
    

    ☣ Warning ☣

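    pl.len() currently returns an unsigned integer (UInt32), though this may change in the future[1].

    We cast it to a signed integer (cast(int) gives Int64) so that the subtraction can go negative instead of overflowing: with an unsigned type, any group larger than the target size wraps around to a huge number, as the uint32 column below shows.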

    dfa.group_by("group", maintain_order=True).agg(int=2 - pl.len().cast(int), uint32=2 - pl.len())
    
    shape: (3, 3)
    ┌───────┬─────┬────────────┐
    │ group ┆ int ┆ uint32     │
    │ ---   ┆ --- ┆ ---        │
    │ str   ┆ i64 ┆ u32        │
    ╞═══════╪═════╪════════════╡
    │ a     ┆ -1  ┆ 4294967295 │ # <- pl.repeat(None, 4294967295) would try to build ~4.3 billion nulls.
    │ b     ┆ 0   ┆ 0          │
    │ c     ┆ 1   ┆ 1          │
    └───────┴─────┴────────────┘
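The 4294967295 in the uint32 column is just 2 - 3 wrapped modulo 2**32. The sketch below models that arithmetic in plain Python (it does not call Polars) and reproduces both columns for the example's group sizes of 3, 2 and 1; the helper name u32_sub is made up for illustration.

```python
def u32_sub(a: int, b: int) -> int:
    """Subtract as unsigned 32-bit integers: the result wraps modulo 2**32."""
    return (a - b) % 2**32


# Group sizes from the example frame: a has 3 rows, b has 2, c has 1.
for group, size in [("a", 3), ("b", 2), ("c", 1)]:
    # signed result vs. wrapped unsigned result
    print(group, 2 - size, u32_sub(2, size))
```

Casting to a signed type first keeps the -1, which is easy to clip to 0 or reject.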
    

    1. https://github.com/pola-rs/polars/issues/17722