Polars - Get column value at another column's min / max value


Given the following polars dataframe:

import polars as pl

df = pl.DataFrame({'A': ['a0', 'a0', 'a1', 'a1'],
                   'B': ['b1', 'b2', 'b1', 'b2'],
                   'x': [0, 10, 5, 1]})

shape: (4, 3)
┌─────┬─────┬─────┐
│ A   ┆ B   ┆ x   │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═════╪═════╪═════╡
│ a0  ┆ b1  ┆ 0   │
│ a0  ┆ b2  ┆ 10  │
│ a1  ┆ b1  ┆ 5   │
│ a1  ┆ b2  ┆ 1   │
└─────┴─────┴─────┘

I want to add a column y that, within each group of A, holds the value of B from the row where x is largest. The result should be the following dataframe:

┌─────┬─────┬─────┬─────┐
│ A   ┆ B   ┆ x   ┆ y   │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 ┆ str │
╞═════╪═════╪═════╪═════╡
│ a0  ┆ b1  ┆ 0   ┆ b2  │
│ a0  ┆ b2  ┆ 10  ┆ b2  │
│ a1  ┆ b1  ┆ 5   ┆ b1  │
│ a1  ┆ b2  ┆ 1   ┆ b1  │
└─────┴─────┴─────┴─────┘

I've tried various versions of df.with_columns(y=pl.col('B').?.over('A')) without any luck.


Solution

  • You can combine pl.Expr.get and pl.Expr.arg_max: arg_max gives the index of the largest x, and get picks the element of B at that index. Wrapping the expression in the window function pl.Expr.over evaluates it separately for each group defined by A and broadcasts the result back to every row of that group.

    df.with_columns(
        pl.col("B").get(pl.col("x").arg_max()).over("A").alias("y")
    )
    
    shape: (4, 4)
    ┌─────┬─────┬─────┬─────┐
    │ A   ┆ B   ┆ x   ┆ y   │
    │ --- ┆ --- ┆ --- ┆ --- │
    │ str ┆ str ┆ i64 ┆ str │
    ╞═════╪═════╪═════╪═════╡
    │ a0  ┆ b1  ┆ 0   ┆ b2  │
    │ a0  ┆ b2  ┆ 10  ┆ b2  │
    │ a1  ┆ b1  ┆ 5   ┆ b1  │
    │ a1  ┆ b2  ┆ 1   ┆ b1  │
    └─────┴─────┴─────┴─────┘