python, dataframe, python-polars

How to run computations on other rows efficiently


I am working with a Polars DataFrame and need to perform computations on each row using values from other rows. Currently, I am using the map_elements method, but it is not efficient.

In the following example, I add two new columns to a DataFrame:

  1. sum_lower: The sum of all elements that are smaller than the current element.
  2. max_other: The maximum value from the DataFrame, excluding the current element.

Here is my current implementation:

import polars as pl

COL_VALUE = "value"

def fun_sum_lower(current_row, df):
    tmp_df = df.filter(pl.col(COL_VALUE) < current_row[COL_VALUE])
    sum_lower = tmp_df.select(pl.sum(COL_VALUE)).item()
    return sum_lower

def fun_max_other(current_row, df):
    tmp_df = df.filter(pl.col(COL_VALUE) != current_row[COL_VALUE])
    max_other = tmp_df.select(pl.col(COL_VALUE)).max().item()
    return max_other

if __name__ == '__main__':
    df = pl.DataFrame({COL_VALUE: [3, 7, 1, 9, 4]})

    df = df.with_columns(
        pl.struct([COL_VALUE])
        .map_elements(lambda row: fun_sum_lower(row, df), return_dtype=pl.Int64)
        .alias("sum_lower")
    )

    df = df.with_columns(
        pl.struct([COL_VALUE])
        .map_elements(lambda row: fun_max_other(row, df), return_dtype=pl.Int64)
        .alias("max_other")
    )

    print(df)

The output of the above code is:

shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ ---   ┆ ---       ┆ ---       │
│ i64   ┆ i64       ┆ i64       │
╞═══════╪═══════════╪═══════════╡
│ 3     ┆ 1         ┆ 9         │
│ 7     ┆ 8         ┆ 9         │
│ 1     ┆ 0         ┆ 9         │
│ 9     ┆ 15        ┆ 7         │
│ 4     ┆ 4         ┆ 9         │
└───────┴───────────┴───────────┘

While this code works, it is not efficient due to the use of lambdas and row-wise operations.

Is there a more efficient way to achieve this in Polars, without using lambdas, iterating over rows, or running Python code?

I also tried using Polars methods: cum_sum, group_by_dynamic, and rolling, but I don't think those can be used for this task.


Solution

  • For your specific use case you don't really need a join; you can calculate the values with window functions.

    (
        df
        .sort("value")
        .with_columns(
            # cumulative sum of everything before the current row in sorted order
            sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
            # the overall max, unless the current row is the max,
            # in which case take the second-largest value
            max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min())
        )
    )
    
    shape: (5, 3)
    ┌───────┬───────────┬───────────┐
    │ value ┆ sum_lower ┆ max_other │
    │ ---   ┆ ---       ┆ ---       │
    │ i64   ┆ i64       ┆ i64       │
    ╞═══════╪═══════════╪═══════════╡
    │ 1     ┆ 0         ┆ 9         │
    │ 3     ┆ 1         ┆ 9         │
    │ 4     ┆ 4         ┆ 9         │
    │ 7     ┆ 8         ┆ 9         │
    │ 9     ┆ 15        ┆ 7         │
    └───────┴───────────┴───────────┘
    

    You can also use pl.DataFrame.with_row_index() to remember the current order, so you can restore it at the end with pl.DataFrame.sort().

    (
        df.with_row_index()
        .sort("value")
        .with_columns(
            sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
            max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min())
        )
        .sort("index")
        .drop("index")
    )
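
    If you do want the fully generic "compare each row against all other rows" pattern (beyond these two columns), it can also stay entirely inside Polars as a cross join plus one aggregation per row. This is a minimal sketch, still O(n²) like the map_elements version but executed natively; "index" and "other" are just placeholder names:

    (
        df.with_row_index()
        .join(df.select(other=pl.col.value), how="cross")   # pair every row with every value
        .group_by("index", maintain_order=True)             # one group per original row
        .agg(
            value=pl.col.value.first(),
            sum_lower=pl.col.other.filter(pl.col.other < pl.col.value).sum(),
            max_other=pl.col.other.filter(pl.col.other != pl.col.value).max(),
        )
        .drop("index")
    )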
    

    Another possible solution is to use DuckDB's integration with Polars.

    First, using window functions, taking advantage of DuckDB's excellent window frame options (notably exclude current row):

    import duckdb
    
    duckdb.sql("""
        select
            d.value,
            coalesce(sum(d.value) over(
                order by d.value
                rows unbounded preceding
                exclude current row
            ), 0) as sum_lower,
            max(d.value) over(
                rows between unbounded preceding and unbounded following
                exclude current row
            ) as max_other
        from df as d
    """).pl()
    
    shape: (5, 3)
    ┌───────┬───────────────┬───────────┐
    │ value ┆ sum_lower     ┆ max_other │
    │ ---   ┆ ---           ┆ ---       │
    │ i64   ┆ decimal[38,0] ┆ i64       │
    ╞═══════╪═══════════════╪═══════════╡
    │ 1     ┆ 0             ┆ 9         │
    │ 3     ┆ 1             ┆ 9         │
    │ 4     ┆ 4             ┆ 9         │
    │ 7     ┆ 8             ┆ 9         │
    │ 9     ┆ 15            ┆ 7         │
    └───────┴───────────────┴───────────┘
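
    Note that sum_lower comes back as decimal[38,0] because DuckDB widens integer sums. If you want i64 like the pure-Polars result, you can either wrap the sum in cast(... as bigint) directly in the SQL, or cast after converting; a small follow-up sketch, assuming the query above is stored in a string called query:

    result = duckdb.sql(query).pl()
    result = result.with_columns(pl.col("sum_lower").cast(pl.Int64))   # back to i64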
    

    Or using a lateral join:

    import duckdb
    
    duckdb.sql("""
        select
            d.value,
            coalesce(s.value, 0) as sum_lower,
            m.value as max_other
        from df as d,
            lateral (select sum(t.value) as value from df as t where t.value < d.value) as s,
            lateral (select max(t.value) as value from df as t where t.value != d.value) as m
    """).pl()
    
    shape: (5, 3)
    ┌───────┬───────────┬───────────┐
    │ value ┆ sum_lower ┆ max_other │
    │ ---   ┆ ---       ┆ ---       │
    │ i64   ┆ i64       ┆ i64       │
    ╞═══════╪═══════════╪═══════════╡
    │ 3     ┆ 1         ┆ 9         │
    │ 7     ┆ 8         ┆ 9         │
    │ 1     ┆ 0         ┆ 9         │
    │ 9     ┆ 15        ┆ 7         │
    │ 4     ┆ 4         ┆ 9         │
    └───────┴───────────┴───────────┘
    

    Duplicate values

    The pure Polars solution above works well if there are no duplicate values, but if there are, you can work around it. Here are two examples, depending on whether you want to keep the original order or not:

    # not keeping original order
    (
        df
        .select(pl.col.value.value_counts()).unnest("value")   # one row per distinct value + its count
        .sort("value")
        .with_columns(
            sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
            max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min()),
            value = pl.col.value.repeat_by("count")             # expand back to original multiplicity
        ).drop("count").explode("value")
    )
    
    # keeping original order
    (
        df.with_row_index()
        .group_by("value").agg("index")   # collect the row indices of each distinct value
        .sort("value")
        .with_columns(
            sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
            max_other =
            pl.when(pl.col.value.max() != pl.col.value)
            .then(pl.col.value.max())
            .otherwise(pl.col.value.top_k(2).min())
        )
        .explode("index")
        .sort("index")
        .drop("index")
    )
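
    Whichever variant you pick, it is easy to sanity-check it against the slow map_elements implementation from the question on a small frame. A quick sketch, where expected is the question's output and result is one of the frames above (both names are placeholders); the DuckDB window variant would additionally need the dtype cast shown earlier:

    from polars.testing import assert_frame_equal

    # sort both frames so the comparison doesn't depend on row order
    assert_frame_equal(expected.sort("value"), result.sort("value"))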