Getting min/max column name in Polars


In Polars, I can get the horizontal max (the maximum value across a set of columns for each row) like this:

import polars as pl

df = pl.DataFrame(
    {
        "a": [1, 8, 3],
        "b": [4, 5, None],
    }
)

df.with_columns(max = pl.max_horizontal("a", "b"))
┌─────┬──────┬─────┐
│ a   ┆ b    ┆ max │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ i64 │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ 4   │
│ 8   ┆ 5    ┆ 8   │
│ 3   ┆ null ┆ 3   │
└─────┴──────┴─────┘

This corresponds to Pandas df[["a", "b"]].max(axis=1).

Now, how do I get the name of the column that holds the max instead of the max value itself? In other words, what is the Polars equivalent of Pandas' df[["a", "b"]].idxmax(axis=1)?

The expected output would be:

┌─────┬──────┬─────┐
│ a   ┆ b    ┆ max │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ str │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ b   │
│ 8   ┆ 5    ┆ a   │
│ 3   ┆ null ┆ a   │
└─────┴──────┴─────┘

Solution

  • You can concatenate the elements into a list using pl.concat_list, get the index of the largest element using pl.Expr.list.arg_max, and replace the index with the column name using pl.Expr.replace.

    mapping = {0: "a", 1: "b"}
    (
        df
        .with_columns(
            pl.concat_list(["a", "b"]).list.arg_max().replace(mapping).alias("max_col")
        )
    )
    

    This can all be wrapped into a function to also handle the creation of the mapping dict.

    def max_col(cols) -> pl.Expr:
        mapping = dict(enumerate(cols))
        return pl.concat_list(cols).list.arg_max().replace(mapping)
    
    df.with_columns(max_col(["a", "b"]).alias("max_col"))
    

    Output:

    shape: (3, 3)
    ┌─────┬──────┬─────────┐
    │ a   ┆ b    ┆ max_col │
    │ --- ┆ ---  ┆ ---     │
    │ i64 ┆ i64  ┆ str     │
    ╞═════╪══════╪═════════╡
    │ 1   ┆ 4    ┆ b       │
    │ 8   ┆ 5    ┆ a       │
    │ 3   ┆ null ┆ a       │
    └─────┴──────┴─────────┘