I am trying to set multiple elements in a Polars DataFrame based on a condition, similar to how it is done in Pandas. Here’s an example in Pandas:
```python
import pandas as pd

df = pd.DataFrame(dict(
    A=[1, 2, 3, 4, 5],
    B=[0, 5, 9, 2, 10],
))
df.loc[df['A'] < df['B'], 'A'] = [100, 210, 320]
print(df)
```
This updates column `A` where `A < B` with `[100, 210, 320]`.
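To make the Pandas semantics explicit, here is a small sketch (assuming a reasonably recent pandas) that shows which row labels the boolean mask selects. The right-hand list is matched to those rows in order, so its length has to equal the number of matching rows.

```python
import pandas as pd

df = pd.DataFrame(dict(A=[1, 2, 3, 4, 5], B=[0, 5, 9, 2, 10]))

# Boolean mask: True where A < B
mask = df["A"] < df["B"]

# Row labels that .loc will assign to, in order
selected = df.index[mask.to_numpy()].tolist()
print(selected)  # [1, 2, 4]

# The values are matched to the selected rows positionally,
# so there must be exactly one value per matching row
new_values = [100, 210, 320]
assert len(new_values) == mask.sum()

df.loc[mask, "A"] = new_values
print(df["A"].tolist())  # [1, 100, 210, 4, 320]
```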
In Polars, I understand that updating a DataFrame in place is not possible, and returning a new DataFrame with the updated elements is fine for me. I have tried the following methods:
1. `Series.scatter` with `map_batches`:
```python
import polars as pl

df = pl.DataFrame(dict(
    A=[1, 2, 3, 4, 5],
    B=[0, 5, 9, 2, 10],
))

def set_elements(cols):
    a, b = cols
    return a.scatter((a < b).arg_true(), [100, 210, 320])

df = df.with_columns(
    pl.map_batches(['A', 'B'], set_elements)
)
```
2. `DataFrame.update()`:
```python
# starting again from the original df, not the result of approach 1
df = df.with_row_index()
df_update = df.filter(pl.col('A') < pl.col('B')).select(
    'index',
    pl.Series('A', [100, 210, 320])
)
df = df.update(df_update, on='index').drop('index')
```
Both approaches work, but they feel cumbersome compared to the straightforward Pandas syntax.
Question:

Is there a simpler or more idiomatic way in Polars to set multiple elements conditionally in a column, similar to the Pandas `loc` syntax?
Updated answer:
You can work around this with `pl.Series.scatter()` combined with `pl.Expr.arg_true()` or `pl.arg_where()`, although it requires accessing the column as a `pl.Series`:
```python
df.with_columns(
    df.get_column("A").scatter(
        # or df["A"].scatter
        df.select((pl.col.A < pl.col.B).arg_true()),
        # or df.select(pl.arg_where(pl.col.A < pl.col.B)),
        [100, 210, 320]
    )
)
```

```
shape: (5, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 0   │
│ 100 ┆ 5   │
│ 210 ┆ 9   │
│ 4   ┆ 2   │
│ 320 ┆ 10  │
└─────┴─────┘
```
Or, building on @Hericks' answer, you can update it in place:
```python
df[df.select((pl.col.A < pl.col.B).arg_true()), "A"] = [100, 210, 320]
# or df[df.select(pl.arg_where(pl.col.A < pl.col.B)), "A"] = [100, 210, 320]
print(df)
```

```
shape: (5, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 0   │
│ 100 ┆ 5   │
│ 210 ┆ 9   │
│ 4   ┆ 2   │
│ 320 ┆ 10  │
└─────┴─────┘
```
Original answer:

I'd say the way @jqurious did it in this answer is probably as short as you can get.
```python
idxs = pl.when(pl.col.A < pl.col.B).then(1).cum_sum() - 1
new_values = pl.lit(pl.Series([100, 210, 320]))
df.with_columns(A = pl.coalesce(new_values.get(idxs), 'A'))
```

```
shape: (5, 2)
┌─────┬─────┐
│ A   ┆ B   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 0   │
│ 100 ┆ 5   │
│ 210 ┆ 9   │
│ 4   ┆ 2   │
│ 320 ┆ 10  │
└─────┴─────┘
```