Get a grouped sum in polars, but keep all individual rows


I am breaking my head over this probably pretty simple question and I just can't find the answer anywhere. I want to create a new column with a grouped sum of another column, but I want to keep all individual rows. This is what the docs show:

import polars as pl

df = pl.DataFrame(
    {
        "a": ["a", "b", "a", "b", "c"],
        "b": [1, 2, 1, 3, 3],
    }
)

df.group_by("a").agg(pl.col("b").sum())  

The output of this would be:

shape: (3, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ a   ┆ 2   │
│ c   ┆ 3   │
│ b   ┆ 5   │
└─────┴─────┘

However, what I need would be this:

shape: (5, 3)
┌─────┬─────┬────────┐
│ a   ┆ b   ┆ sum(b) │
│ --- ┆ --- ┆ ------ │
│ str ┆ i64 ┆ i64    │
╞═════╪═════╪════════╡
│ a   ┆ 1   ┆ 2      │
│ b   ┆ 2   ┆ 5      │
│ a   ┆ 1   ┆ 2      │
│ b   ┆ 3   ┆ 5      │
│ c   ┆ 3   ┆ 3      │
└─────┴─────┴────────┘

I could compute the sums in a separate df and then join them back onto the original one, but I am pretty sure there is an easier solution.


Solution

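  • All you need is a window function. .over() computes the aggregation per group and broadcasts the result back to every row of that group, so no rows are dropped: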

    df.with_columns(
        b_sum=pl.col("b").sum().over(pl.col("a"))
    )
    
    
    shape: (5, 3)
    ┌─────┬─────┬───────┐
    │ a   ┆ b   ┆ b_sum │
    │ --- ┆ --- ┆ ---   │
    │ str ┆ i64 ┆ i64   │
    ╞═════╪═════╪═══════╡
    │ a   ┆ 1   ┆ 2     │
    │ b   ┆ 2   ┆ 5     │
    │ a   ┆ 1   ┆ 2     │
    │ b   ┆ 3   ┆ 5     │
    │ c   ┆ 3   ┆ 3     │
    └─────┴─────┴───────┘