Tags: python, dataframe, window-functions, python-polars

Polars: Nesting `over` calls


Context. I have written a function that computes, for each row, the mean of all elements in a column except the elements in that row's group.

# Toy frame: three groups ("A", "B", "C") with two values each.
df = pl.DataFrame(
    data={
        "group": ["A", "A", "B", "B", "C", "C"],
        "value": [1, 3, 2, 4, 3, 5],
    }
)

def sum_excl_group(val_exp: pl.Expr, group_expr: pl.Expr) -> pl.Expr:
    """Return the sum of ``val_exp`` outside the current ``group_expr`` group.

    Computed as the overall sum (of the whole column, or of the enclosing
    window when this expression is itself wrapped in ``over``) minus the sum
    within the row's own group.
    """
    total = val_exp.sum()
    return total - total.over(group_expr)

def count_non_null_excl_group(val_exp: pl.Expr, group_expr: pl.Expr) -> pl.Expr:
    """Return how many non-null ``val_exp`` values lie outside the current group.

    Counts non-null entries overall and subtracts the count within the row's
    own ``group_expr`` group.
    """
    non_null = val_exp.is_not_null()
    return non_null.sum() - non_null.sum().over(group_expr)

def mean_excl_group(val_exp: pl.Expr, group_expr: pl.Expr) -> pl.Expr:
    """Return the mean of ``val_exp`` over all rows outside the current group.

    The numerator is ``sum_excl_group`` and the denominator is
    ``count_non_null_excl_group``, so only non-null values are counted.
    NOTE(review): if every non-null value belongs to the current group, the
    denominator is 0; the result then follows Polars' float-division rules.
    """
    numerator = sum_excl_group(val_exp, group_expr)
    denominator = count_non_null_excl_group(val_exp, group_expr)
    return numerator / denominator

# For every row: mean of "value" over all rows belonging to other groups.
df.with_columns(
    mean_excl_group(pl.col("value"), pl.col("group"))
    .alias("mean_excl_group")
)

This gives the expected result.

shape: (6, 3)
┌───────┬───────┬─────────────────┐
│ group ┆ value ┆ mean_excl_group │
│ ---   ┆ ---   ┆ ---             │
│ str   ┆ i64   ┆ f64             │
╞═══════╪═══════╪═════════════════╡
│ A     ┆ 1     ┆ 3.5             │
│ A     ┆ 3     ┆ 3.5             │
│ B     ┆ 2     ┆ 3.0             │
│ B     ┆ 4     ┆ 3.0             │
│ C     ┆ 3     ┆ 2.5             │
│ C     ┆ 5     ┆ 2.5             │
└───────┴───────┴─────────────────┘

Problem. I would now like to run this function within an `over` context, to obtain the mean of all elements in a group except the elements in the current subgroup.

I would've expected the following to work.

# Same values as before, now with a unique subgroup label per row.
df = pl.DataFrame(
    data={
        "group": ["A", "A", "B", "B", "C", "C"],
        "subgroup": ["a", "b", "c", "d", "e", "f"],
        "value": [1, 3, 2, 4, 3, 5],
    }
)

# Intended: the per-"subgroup" exclusion mean, evaluated inside each "group"
# window. Older Polars versions reject this nested window expression with
# InvalidOperationError.
df.with_columns(
    mean_excl_group(pl.col("value"), pl.col("subgroup"))
    .over("group")
    .alias("mean_excl_group")
)

but instead I get an `InvalidOperationError`:

InvalidOperationError: window expression not allowed in aggregation

Attempt. For now, I have "solved" this issue by avoiding the nested over calls.

def sum_excl_group(val_exp: pl.Expr, coarse_group_expr: pl.Expr, fine_group_expr: pl.Expr) -> pl.Expr:
    """Return the sum of ``val_exp`` in the coarse group, excluding the current fine group.

    The fine window is partitioned by *both* keys, so a fine-group label that
    reappears under a different coarse group is not pooled with it.
    """
    coarse_sum = val_exp.sum().over(coarse_group_expr)
    fine_sum = val_exp.sum().over(coarse_group_expr, fine_group_expr)
    return coarse_sum - fine_sum

def count_non_null_excl_group(val_exp: pl.Expr, coarse_group_expr: pl.Expr, fine_group_expr: pl.Expr) -> pl.Expr:
    """Return the number of non-null ``val_exp`` values in the coarse group, excluding the current fine group.

    As in ``sum_excl_group``, the fine window is keyed on (coarse, fine) so
    that repeated fine labels across coarse groups stay separate.
    """
    non_null = val_exp.is_not_null()
    coarse_count = non_null.sum().over(coarse_group_expr)
    fine_count = non_null.sum().over(coarse_group_expr, fine_group_expr)
    return coarse_count - fine_count

def mean_excl_group(val_exp: pl.Expr, coarse_group_expr: pl.Expr, fine_group_expr: pl.Expr) -> pl.Expr:
    """Return the mean of ``val_exp`` within the coarse group, excluding the current fine group.

    NOTE(review): if every non-null value of the coarse group lies in the
    current fine group, the denominator is 0; the result then follows Polars'
    float-division rules.
    """
    numerator = sum_excl_group(val_exp, coarse_group_expr, fine_group_expr)
    denominator = count_non_null_excl_group(val_exp, coarse_group_expr, fine_group_expr)
    return numerator / denominator

# Exclusion mean with both granularities passed explicitly (no nested over).
df.with_columns(
    mean_excl_group(pl.col("value"), pl.col("group"), pl.col("subgroup"))
    .alias("mean_excl_group")
)

This gives the expected result.

shape: (6, 4)
┌───────┬──────────┬───────┬─────────────────┐
│ group ┆ subgroup ┆ value ┆ mean_excl_group │
│ ---   ┆ ---      ┆ ---   ┆ ---             │
│ str   ┆ str      ┆ i64   ┆ f64             │
╞═══════╪══════════╪═══════╪═════════════════╡
│ A     ┆ a        ┆ 1     ┆ 3.0             │
│ A     ┆ b        ┆ 3     ┆ 1.0             │
│ B     ┆ c        ┆ 2     ┆ 4.0             │
│ B     ┆ d        ┆ 4     ┆ 2.0             │
│ C     ┆ e        ┆ 3     ┆ 5.0             │
│ C     ┆ f        ┆ 5     ┆ 3.0             │
└───────┴──────────┴───────┴─────────────────┘

However, this requires me to pass both granularities to every function, which makes the code more bloated than (I hope) it needs to be. Is there a cleaner, more polaric way to solve the problem of nested `over` calls?


Solution

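  • This is now valid as of Polars 1.36.0: nested window expressions are supported, so the original two-argument `mean_excl_group` from the question can be wrapped in `.over("group")` directly.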

    # Uses the two-argument mean_excl_group from the question: its inner
    # window over "subgroup" is evaluated inside each outer "group" window.
    df.with_columns(
        mean_excl_group(pl.col("value"), pl.col("subgroup"))
        .over("group")
        .alias("mean_excl_group")
    )
    # shape: (6, 4)
    # ┌───────┬──────────┬───────┬─────────────────┐
    # │ group ┆ subgroup ┆ value ┆ mean_excl_group │
    # │ ---   ┆ ---      ┆ ---   ┆ ---             │
    # │ str   ┆ str      ┆ i64   ┆ f64             │
    # ╞═══════╪══════════╪═══════╪═════════════════╡
    # │ A     ┆ a        ┆ 1     ┆ 3.0             │
    # │ A     ┆ b        ┆ 3     ┆ 1.0             │
    # │ B     ┆ c        ┆ 2     ┆ 4.0             │
    # │ B     ┆ d        ┆ 4     ┆ 2.0             │
    # │ C     ┆ e        ┆ 3     ┆ 5.0             │
    # │ C     ┆ f        ┆ 5     ┆ 3.0             │
    # └───────┴──────────┴───────┴─────────────────┘