
Compute group-wise residual for polars data frame


I have a data frame with X and Y values as well as two grouping columns, GROUP1 and GROUP2. For each combination of the two groups, I want to fit a linear model to the X and Y data and then subtract the fit from the true data to get a residual.

I'm currently implementing this in the following way:

import polars as pl
import numpy as np

# --- Sample DataFrame for demonstration purposes 
df = pl.DataFrame(
    {
        "GROUP1": [1, 1, 1, 2, 2, 2],
        "GROUP2": ["A", "A", "A", "B", "B", "B"],
        "X": [0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
        "Y": [5.0, 7.0, 9.0, 3.0, 4.0, 6.0],
    }
)


# --- Function to subtract linear best fit per group 
def subtract_linear_best_fit(df: pl.DataFrame) -> pl.DataFrame:
    result = []
    for _, subdf in df.group_by(["GROUP1", "GROUP2"]):
        x = subdf["X"].to_numpy()
        y = subdf["Y"].to_numpy()

        a, b = np.polyfit(x, y, 1)
        residuals = y - (a * x + b)

        result.append(subdf.with_columns(pl.Series("residual", residuals)))
    return pl.concat(result)


# --- Apply function 
df_with_residuals = subtract_linear_best_fit(df)
print(df_with_residuals)

But this does not feel idiomatic, since it uses neither .group_by(...).agg(...) nor .with_columns((...).over(...)). I tried both approaches, but I either lost columns from the original data frame or only got a per-group summary. What I want is a data frame of the same height with just one additional column.

Is there any way to avoid concatenating data frames inside the loop? Ideally there would be something like .group_by().pipe() or .pipe().over().


Solution

  • You can use a Struct to pass multiple columns to .map_batches(), and .over() to apply the function once per group:

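    def subtract_linear_best_fit(s: pl.Series) -> pl.Series:
        # s is a Struct Series with the fields X and Y for a single group
        x, y = s.struct.unnest()
        # Degree-1 fit: a is the slope, b is the intercept
        a, b = np.polyfit(x, y, 1)
        # Residual = observed value minus fitted value
        return y - (a * x + b)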
        
    df.with_columns(
        pl.struct("X", "Y").map_batches(subtract_linear_best_fit)
          .over("GROUP1", "GROUP2")
          .alias("residual")
    )
    
    shape: (6, 5)
    ┌────────┬────────┬─────┬─────┬─────────────┐
    │ GROUP1 ┆ GROUP2 ┆ X   ┆ Y   ┆ residual    │
    │ ---    ┆ ---    ┆ --- ┆ --- ┆ ---         │
    │ i64    ┆ str    ┆ f64 ┆ f64 ┆ f64         │
    ╞════════╪════════╪═════╪═════╪═════════════╡
    │ 1      ┆ A      ┆ 0.0 ┆ 5.0 ┆ -1.7764e-15 │
    │ 1      ┆ A      ┆ 1.0 ┆ 7.0 ┆ -1.7764e-15 │
    │ 1      ┆ A      ┆ 2.0 ┆ 9.0 ┆ 0.0         │
    │ 2      ┆ B      ┆ 0.0 ┆ 3.0 ┆ 0.166667    │
    │ 2      ┆ B      ┆ 1.0 ┆ 4.0 ┆ -0.333333   │
    │ 2      ┆ B      ┆ 2.0 ┆ 6.0 ┆ 0.166667    │
    └────────┴────────┴─────┴─────┴─────────────┘