
Polars: Use column values to reference another column in a when/then expression


I have a Polars dataframe where I'd like to derive a new column using a when/then expression. The values of the new column should be taken from a different column in the same dataframe. However, the column from which to take the values differs from row to row.

Here's a simple example:

import polars as pl

df = pl.DataFrame(
    {
        "frequency": [0.5, None, None, None],
        "frequency_ref": ["a", "z", "a", "a"],
        "a": [1, 2, 3, 4],
        "z": [5, 6, 7, 8],
    }
)

The resulting dataframe should look like this:

res = pl.DataFrame(
    {
        "frequency": [0.5, None, None, None],
        "frequency_ref": ["a", "z", "a", "a"],
        "a": [1, 2, 3, 4],
        "z": [5, 6, 7, 8],
        "res": [0.5, 6, 3, 4]
    }
)

I tried to create a dynamic reference using a nested pl.col:

# Case 1) Fixed value is given
fixed_freq_condition = pl.col("frequency").is_not_null() & pl.col("frequency").is_not_nan()
# Case 2) Reference to distribution data is given
ref_freq_condition = pl.col("frequency_ref").is_not_null()

# Apply the conditions to calculate res
df = df.with_columns(
    pl.when(fixed_freq_condition)
    .then(pl.col("frequency"))
    .when(ref_freq_condition)
    .then(
      pl.col(pl.col("frequency_ref"))
    )
    .otherwise(0.0)
    .alias("res"),
)

This fails with TypeError: invalid input for "col". Expected "str" or "DataType", got 'Expr'.

What works (but only as an interim solution) is explicitly listing every possible column value in a very long when/then expression. This is far from optimal: the column names might change in the future, and it produces a lot of code repetition.

df = df.with_columns(
    pl.when(fixed_freq_condition)
    .then(pl.col("frequency"))
    .when(pl.col("frequency_ref") == "a")
    .then(pl.col("a"))
    # ... more entries
    .when(pl.col("frequency_ref") == "z")
    .then(pl.col("z"))
    .otherwise(0.0)
    .alias("res"),
)

Solution

  • You could build the when/then in a loop:

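    # Only the distinct, non-null names are needed; each must be a column of df
    freq_refs = df.get_column("frequency_ref").drop_nulls().unique()
    expr = pl.when(False).then(None)  # dummy starter so .when() can be chained in the loop
    for c in freq_refs:
        expr = expr.when(pl.col("frequency_ref") == c).then(pl.col(c))
    expr = expr.otherwise(0)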
    
    
    # Apply the conditions to calculate res
    df = df.with_columns(
        pl.when(fixed_freq_condition)
        .then(pl.col("frequency"))
        .when(ref_freq_condition)
        .then(expr)
        .otherwise(0.0)
        .alias("res"),
    )
    
    df
    
    shape: (4, 5)
    ┌───────────┬───────────────┬─────┬─────┬─────┐
    │ frequency ┆ frequency_ref ┆ a   ┆ z   ┆ res │
    │ ---       ┆ ---           ┆ --- ┆ --- ┆ --- │
    │ f64       ┆ str           ┆ i64 ┆ i64 ┆ f64 │
    ╞═══════════╪═══════════════╪═════╪═════╪═════╡
    │ 0.5       ┆ a             ┆ 1   ┆ 5   ┆ 0.5 │
    │ null      ┆ z             ┆ 2   ┆ 6   ┆ 6.0 │
    │ null      ┆ a             ┆ 3   ┆ 7   ┆ 3.0 │
    │ null      ┆ a             ┆ 4   ┆ 8   ┆ 4.0 │
    └───────────┴───────────────┴─────┴─────┴─────┘