
How to add a temp column while avoiding column name conflicts


I have a custom function that does some data cleaning on a polars DataFrame. For efficiency, I cache some intermediate results in temporary columns and drop those columns at the end.

This is my function:

import polars as pl


def clean_data(df, cols):
    return (
        df.with_columns(pl.mean(col).alias(f"__{col}_mean") for col in cols)
        .with_columns(
            pl.when(pl.col(col) < pl.col(f"__{col}_mean") * 3 / 4)
            .then(pl.col(f"__{col}_mean") * 3 / 4)
            .when(pl.col(col) > pl.col(f"__{col}_mean") * 5 / 4)
            .then(pl.col(f"__{col}_mean") * 5 / 4)
            .otherwise(pl.col(col))
            .alias(col)
            for col in cols
        )
        .select(pl.exclude(f"__{col}_mean" for col in cols))
    )

It works fine for "normal" inputs:

df = pl.DataFrame(
    {
        "a": [1, 2, 3, 4, 5, 12, 28],
        "a2": [1, 2, 3, 4, 5, 6, 7],
    }
)

clean_data(df, ["a", "a2"])

shape: (7, 2)
┌──────────┬─────┐
│ a        ┆ a2  │
│ ---      ┆ --- │
│ f64      ┆ f64 │
╞══════════╪═════╡
│ 5.892857 ┆ 3.0 │
│ 5.892857 ┆ 3.0 │
│ 5.892857 ┆ 3.0 │
│ 5.892857 ┆ 4.0 │
│ 5.892857 ┆ 5.0 │
│ 9.821429 ┆ 5.0 │
│ 9.821429 ┆ 5.0 │
└──────────┴─────┘

However, the names of my cached columns might conflict with columns that already exist in the user's input, for example:

df = pl.DataFrame(
    {
        "a": [1, 2, 3, 4, 5, 12, 28],
        "a2": [1, 2, 3, 4, 5, 6, 7],
        "__a_mean": [1, 1, 1, 1, 1, 1, 1],
    }
)

clean_data(df, ["a", "a2"])

shape: (7, 2)
┌──────────┬─────┐
│ a        ┆ a2  │
│ ---      ┆ --- │
│ f64      ┆ f64 │
╞══════════╪═════╡
│ 5.892857 ┆ 3.0 │
│ 5.892857 ┆ 3.0 │
│ 5.892857 ┆ 3.0 │
│ 5.892857 ┆ 4.0 │
│ 5.892857 ┆ 5.0 │
│ 9.821429 ┆ 5.0 │
│ 9.821429 ┆ 5.0 │
└──────────┴─────┘

As you can see, the __a_mean column from the original DataFrame is overwritten by my temp column and then dropped, so it is missing from the result.

Is there a way to append temp columns in the middle of calculations and make sure that generated temp column names do not exist in the original DataFrame?

Alternatively, is there a way to implement my function above without caching any results and without sacrificing performance?


Solution

  • I'm not sure how much overhead this would add, but you could use .clone() in combination with .update():

    def clean_data(df, cols):
        return (
            df.update(
                df.clone()
                .with_columns(pl.mean(col).alias(f"__{col}_mean") for col in cols)
                .with_columns(
                    pl.when(pl.col(col) < pl.col(f"__{col}_mean") * 3 / 4)
                    .then(pl.col(f"__{col}_mean") * 3 / 4)
                    .when(pl.col(col) > pl.col(f"__{col}_mean") * 5 / 4)
                    .then(pl.col(f"__{col}_mean") * 5 / 4)
                    .otherwise(pl.col(col))
                    .alias(col)
                    for col in cols
                )
                .select(pl.exclude(f"__{col}_mean" for col in cols))
            )
        )
    
    

    The .clone docs say it is a cheap operation; .update performs a .join internally.

    >>> clean_data(df, ["a", "a2"])
    shape: (7, 3)
    ┌──────────┬─────┬──────────┐
    │ a        ┆ a2  ┆ __a_mean │
    │ ---      ┆ --- ┆ ---      │
    │ f64      ┆ f64 ┆ i64      │
    ╞══════════╪═════╪══════════╡
    │ 5.892857 ┆ 3.0 ┆ 1        │
    │ 5.892857 ┆ 3.0 ┆ 1        │
    │ 5.892857 ┆ 3.0 ┆ 1        │
    │ 5.892857 ┆ 4.0 ┆ 1        │
    │ 5.892857 ┆ 5.0 ┆ 1        │
    │ 9.821429 ┆ 5.0 ┆ 1        │
    │ 9.821429 ┆ 5.0 ┆ 1        │
    └──────────┴─────┴──────────┘
    

    Perhaps explicitly storing the clone in a variable would make the code more self-documenting.

    def clean_data(df, cols):
        # Map each input column to the name of its temporary mean column.
        mean_cols = {col: f"__{col}_mean" for col in cols}

        def formula(x, y):
            # Clamp x to the range [3/4 * y, 5/4 * y].
            return (
                pl.when(x < y * 3 / 4)
                .then(y * 3 / 4)
                .when(x > y * 5 / 4)
                .then(y * 5 / 4)
                .otherwise(x)
            )

        # Build the cleaned columns on a clone and drop the temp columns from it.
        mean_df = (
            df.clone()
            .with_columns(
                pl.mean(col).alias(mean) for col, mean in mean_cols.items()
            )
            .with_columns(
                formula(x=pl.col(col), y=pl.col(mean)).alias(col)
                for col, mean in mean_cols.items()
            )
            .select(pl.exclude(*mean_cols.values()))
        )

        # Copy the cleaned columns back into the original frame.
        return df.update(mean_df)