
How to set multiple elements conditionally in Polars similar to Pandas?


I am trying to set multiple elements in a Polars DataFrame based on a condition, similar to how it is done in Pandas. Here’s an example in Pandas:

import pandas as pd

df = pd.DataFrame(dict(
    A=[1, 2, 3, 4, 5],
    B=[0, 5, 9, 2, 10],
))

df.loc[df['A'] < df['B'], 'A'] = [100, 210, 320]
print(df)

This assigns 100, 210 and 320, in row order, to the entries of column A where A < B.

In Polars, I know that updating a DataFrame in place is not possible, and it is fine to return a new DataFrame with the updated elements. I have tried the following methods:

Attempt 1: Using Series.scatter with map_batches

import polars as pl

df = pl.DataFrame(dict(
    A=[1, 2, 3, 4, 5],
    B=[0, 5, 9, 2, 10],
))

def set_elements(cols):
    a, b = cols
    return a.scatter((a < b).arg_true(), [100, 210, 320])

df = df.with_columns(
    pl.map_batches(['A', 'B'], set_elements)
)

Attempt 2: Creating an update DataFrame and using update()

df = df.with_row_index()
df_update = df.filter(pl.col('A') < pl.col('B')).select(
    'index',
    pl.Series('A', [100, 210, 320])
)
df = df.update(df_update, on='index').drop('index')

Both approaches work, but they feel cumbersome compared to the straightforward Pandas syntax.

Question:
Is there a simpler or more idiomatic way in Polars to set multiple elements conditionally in a column, similar to the Pandas loc syntax?


Solution

  • Update: You can use pl.Series.scatter() together with pl.Expr.arg_true() or pl.arg_where() to get the row indices, although this requires accessing the column as a pl.Series:

    df.with_columns(
        df.get_column("A").scatter(  # or df["A"].scatter(...)
            # row indices where A < B, as a Series
            df.select((pl.col.A < pl.col.B).arg_true()).to_series(),
            # or df.select(pl.arg_where(pl.col.A < pl.col.B)).to_series(),
            [100, 210, 320],
        )
    )
    
    shape: (5, 2)
    ┌─────┬─────┐
    │ A   ┆ B   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 1   ┆ 0   │
    │ 100 ┆ 5   │
    │ 210 ┆ 9   │
    │ 4   ┆ 2   │
    │ 320 ┆ 10  │
    └─────┴─────┘
    

    Or, building on @Hericks' answer, you can update it in place:

    df[df.select((pl.col.A < pl.col.B).arg_true()).to_series(), "A"] = [100, 210, 320]
    # or df[df.select(pl.arg_where(pl.col.A < pl.col.B)).to_series(), "A"] = [100, 210, 320]
    
    shape: (5, 2)
    ┌─────┬─────┐
    │ A   ┆ B   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 1   ┆ 0   │
    │ 100 ┆ 5   │
    │ 210 ┆ 9   │
    │ 4   ┆ 2   │
    │ 320 ┆ 10  │
    └─────┴─────┘
    

    Original answer: the approach @jqurious used in this answer is probably as short as it gets.

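    # 1 on rows where A < B, null elsewhere; the running sum numbers the
    # matching rows 1, 2, 3, ... and "- 1" turns them into 0-based positions
    idxs = pl.when(pl.col.A < pl.col.B).then(1).cum_sum() - 1
    new_values = pl.lit(pl.Series([100, 210, 320]))
    
    # gather picks new_values[idx] for each row (null where idx is null);
    # coalesce then falls back to the existing A value on those rows
    df.with_columns(A = pl.coalesce(new_values.gather(idxs), 'A'))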
    
    shape: (5, 2)
    ┌─────┬─────┐
    │ A   ┆ B   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 1   ┆ 0   │
    │ 100 ┆ 5   │
    │ 210 ┆ 9   │
    │ 4   ┆ 2   │
    │ 320 ┆ 10  │
    └─────┴─────┘