In Polars, I can get the horizontal max (the maximum value across a set of columns for each row) like this:
import polars as pl

df = pl.DataFrame(
    {
        "a": [1, 8, 3],
        "b": [4, 5, None],
    }
)
df.with_columns(max=pl.max_horizontal("a", "b"))
┌─────┬──────┬─────┐
│ a   ┆ b    ┆ max │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ i64 │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ 4   │
│ 8   ┆ 5    ┆ 8   │
│ 3   ┆ null ┆ 3   │
└─────┴──────┴─────┘
This corresponds to pandas' df[["a", "b"]].max(axis=1).
Now, how do I get the name of the column holding the max instead of the max value itself?
In other words, what is the Polars version of pandas' df[CHANGE_COLS].idxmax(axis=1)?
The expected output would be:
┌─────┬──────┬─────┐
│ a   ┆ b    ┆ max │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ str │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ b   │
│ 8   ┆ 5    ┆ a   │
│ 3   ┆ null ┆ a   │
└─────┴──────┴─────┘
You can concatenate the elements into a list using pl.concat_list, get the index of the largest element using pl.Expr.list.arg_max, and replace the index with the column name using pl.Expr.replace.
mapping = {0: "a", 1: "b"}
(
    df
    .with_columns(
        pl.concat_list(["a", "b"]).list.arg_max().replace(mapping).alias("max_col")
    )
)
This can all be wrapped into a function to also handle the creation of the mapping dict.
def max_col(cols) -> pl.Expr:
    # Map each list position back to its column name, e.g. {0: "a", 1: "b"}.
    mapping = dict(enumerate(cols))
    return pl.concat_list(cols).list.arg_max().replace(mapping)

df.with_columns(max_col(["a", "b"]).alias("max_col"))
Output:
shape: (3, 3)
┌─────┬──────┬─────────┐
│ a   ┆ b    ┆ max_col │
│ --- ┆ ---  ┆ ---     │
│ i64 ┆ i64  ┆ str     │
╞═════╪══════╪═════════╡
│ 1   ┆ 4    ┆ b       │
│ 8   ┆ 5    ┆ a       │
│ 3   ┆ null ┆ a       │
└─────┴──────┴─────────┘