Given the following polars dataframe:
import polars as pl

df = pl.DataFrame({'A': ['a0', 'a0', 'a1', 'a1'],
                   'B': ['b1', 'b2', 'b1', 'b2'],
                   'x': [0, 10, 5, 1]})
shape: (4, 3)
┌─────┬─────┬─────┐
│ A ┆ B ┆ x │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═════╪═════╪═════╡
│ a0 ┆ b1 ┆ 0 │
│ a0 ┆ b2 ┆ 10 │
│ a1 ┆ b1 ┆ 5 │
│ a1 ┆ b2 ┆ 1 │
└─────┴─────┴─────┘
I want to add a column y which groups by A and selects the value from B with the maximum corresponding x. The following dataframe should be the result:
┌─────┬─────┬─────┬─────┐
│ A ┆ B ┆ x ┆ y │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 ┆ str │
╞═════╪═════╪═════╪═════╡
│ a0 ┆ b1 ┆ 0 ┆ b2 │
│ a0 ┆ b2 ┆ 10 ┆ b2 │
│ a1 ┆ b1 ┆ 5 ┆ b1 │
│ a1 ┆ b2 ┆ 1 ┆ b1 │
└─────┴─────┴─────┴─────┘
I've tried various versions of df.with_columns(y=pl.col('B').?.over('A')) without any luck.
You can use pl.Expr.get and pl.Expr.arg_max to obtain the value of B at the row with the maximum value of x. This can be combined with the window function pl.Expr.over to perform the operation separately for each group defined by A.
df.with_columns(
pl.col("B").get(pl.col("x").arg_max()).over("A").alias("y")
)
shape: (4, 4)
┌─────┬─────┬─────┬─────┐
│ A ┆ B ┆ x ┆ y │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 ┆ str │
╞═════╪═════╪═════╪═════╡
│ a0 ┆ b1 ┆ 0 ┆ b2 │
│ a0 ┆ b2 ┆ 10 ┆ b2 │
│ a1 ┆ b1 ┆ 5 ┆ b1 │
│ a1 ┆ b2 ┆ 1 ┆ b1 │
└─────┴─────┴─────┴─────┘