Tags: python, string, lazy-evaluation, python-polars

Clean string column with messy data into numeric in polars


The following code has been given:


import polars as pl
import numpy as np

# Seed NumPy's RNG so the simulated data is reproducible
np.random.seed(0)

# Number of rows to simulate
n = 35000000

# Candidate raw readings for each biomarker column
values_CRP = ["10", "2", "3", None, "<4", ">5"]
values_CRP2 = ["10", "12", "<5", "NA", ">5", "5"]
values_CRP3 = ["10", "12.3", "<5", "NA", ">5.5", "4"]

# Sample each column independently (with replacement) from its value pool
df = pl.DataFrame({
    col: np.random.choice(pool, n, replace=True)
    for col, pool in (
        ("CRP", values_CRP),
        ("CRP2", values_CRP2),
        ("CRP3", values_CRP3),
    )
})

Assume that these are columns of 3 different biomarkers. I want to clean them by grabbing the numeric part of each string value, whenever the string does not start with "<" or ">" ("10" should give 10), leave the nulls as they are (meaning the patient did not have any measurement at that time point), and replace the values starting with "<" or ">" with the median of the values which fall below or above the second element of the respective value. For example, "<5" should be replaced with the median of observations with a biomarker value below 5. For ">5", we take observations above. If we have a value of ">10000", and there are no observations above 10000, then we null it. Same with <.

Desired output for minimal example:


# Minimal example: raw strings in "Current" and the desired cleaned
# numeric result in "Goal".
df1 = pl.DataFrame(
    {
        "Current": ["<4", "3", "2", None, ">5", "10"],
        "Goal": [2.5, 3, 2, None, 10, 10],
    }
)

Ideally and because I have 11 columns in reality and almost 40 million rows, I would like to do as much as possible in lazy mode.


Solution

  • I think you would need to join the dataframe on itself or run some kind of subquery, and for this task I find DuckDB to be very user-friendly:

    # Pre-cast once: plain numeric strings become floats, while "<4"/">5"
    # (and nulls) fail the non-strict cast and come through as null.
    df = df.with_columns(
        num = pl.col.Current.cast(pl.Float64, strict=False)
    )
    
    # DuckDB string indexing is 1-based: Current[1] is the first character,
    # Current[2:] the remainder.  Each correlated subquery takes the median
    # of the pre-cast numeric values on the requested side of the bound;
    # with no qualifying rows median() yields NULL, which matches the
    # ">10000" -> null requirement.
    duckdb.sql("""
        select
            d.Current,
            d.Goal,
            case
                when d.Current[1] == '<' then
                    (select median(tt.num) from df as tt where tt.num < try_cast(d.Current[2:] as float))
                when d.Current[1] == '>' then
                    (select median(tt.num) from df as tt where tt.num > try_cast(d.Current[2:] as float))
                else d.num
            end as Calc
        from df as d
    """).pl()
    
    shape: (6, 3)
    ┌─────────┬──────┬──────┐
    │ Current ┆ Goal ┆ Calc │
    │ ---     ┆ ---  ┆ ---  │
    │ str     ┆ f64  ┆ f64  │
    ╞═════════╪══════╪══════╡
    │ <4      ┆ 2.5  ┆ 2.5  │
    │ 3       ┆ 3.0  ┆ 3.0  │
    │ 2       ┆ 2.0  ┆ 2.0  │
    │ null    ┆ null ┆ null │
    │ >5      ┆ 10.0 ┆ 10.0 │
    │ 10      ┆ 10.0 ┆ 10.0 │
    └─────────┴──────┴──────┘
    

    For pure Polars you could precalculate the median for all unique censored values first and then just join:

    # Numeric observations only: plain strings cast cleanly; censored
    # values and nulls are removed by drop_nulls().
    df_num = df.select(num=pl.col("Current").cast(pl.Float64, strict=False)).drop_nulls()
    
    # One row per distinct censored value, split into its comparison
    # operator (first character) and numeric bound (the rest).
    df_calc = (
        df
        .filter(pl.col("Current").str.head(1).is_in(["<", ">"]))
        .select(
            pl.col("Current"),
            oper=pl.col("Current").str.head(1),
            bound=pl.col("Current").str.tail(-1).cast(pl.Float64),
        )
        .unique()
    )
    
    # Pair every numeric observation with every censored value, keep the
    # observations on the correct side of the bound, then reduce them to
    # one median per censored value.
    df_mapping = (
        df_num
        .join(df_calc, how="cross")
        .filter(
            ((pl.col("oper") == "<") & (pl.col("num") < pl.col("bound")))
            | ((pl.col("oper") == ">") & (pl.col("num") > pl.col("bound")))
        )
        .group_by("Current")
        .agg(pl.col("num").median())
    )
    
    # Attach the precomputed medians, then fall back to the plain numeric
    # value for uncensored rows.  Casting the fallback (non-strictly, so
    # censored strings stay null) keeps the column Float64 instead of
    # filling a Float64 column with a String column.
    (
        df
        .join(df_mapping, on="Current", how="left")
        .with_columns(
            pl.col.num.fill_null(pl.col.Current.cast(pl.Float64, strict=False))
        )
    )
    
    shape: (6, 3)
    ┌─────────┬──────┬──────┐
    │ Current ┆ Goal ┆ num  │
    │ ---     ┆ ---  ┆ ---  │
    │ str     ┆ f64  ┆ f64  │
    ╞═════════╪══════╪══════╡
    │ <4      ┆ 2.5  ┆ 2.5  │
    │ 3       ┆ 3.0  ┆ 3.0  │
    │ 2       ┆ 2.0  ┆ 2.0  │
    │ null    ┆ null ┆ null │
    │ >5      ┆ 10.0 ┆ 10.0 │
    │ 10      ┆ 10.0 ┆ 10.0 │
    └─────────┴──────┴──────┘
    

    If you want to extend it to multiple columns, you can do the same, but unpivot() your DataFrame first, calculate the values, and then pivot() it back again.

    # Two biomarker columns whose censored values sit in different rows
    df = pl.DataFrame(
        {
            "CRP1": ["<4", "3", "2", None, ">5", "10"],
            "CRP2": ["1", "2", None, "3", "4", "<2"],
        }
    )
    
    shape: (6, 2)
    ┌──────┬──────┐
    │ CRP1 ┆ CRP2 │
    │ ---  ┆ ---  │
    │ str  ┆ str  │
    ╞══════╪══════╡
    │ <4   ┆ 1    │
    │ 3    ┆ 2    │
    │ 2    ┆ null │
    │ null ┆ 3    │
    │ >5   ┆ 4    │
    │ 10   ┆ <2   │
    └──────┴──────┘
    
    # Long format: one row per (row index, source column, raw value); the
    # row index is kept so the frame can be pivoted back later.
    df_unpivot = df.with_row_index().unpivot(index="index")
    
    # Cleanly numeric observations only, tagged with their source column;
    # censored entries and nulls drop out via the non-strict cast.
    df_num = (
        df_unpivot
        .select(
            pl.col("variable"),
            num=pl.col("value").cast(pl.Float64, strict=False),
        )
        .drop_nulls()
    )
    
    # Distinct censored entries per column, decomposed into the comparison
    # operator (first character) and the numeric bound (the rest).
    df_calc = (
        df_unpivot
        .filter(pl.col("value").str.head(1).is_in(["<", ">"]))
        .select(
            pl.col("variable"),
            pl.col("value"),
            oper=pl.col("value").str.head(1),
            bound=pl.col("value").str.tail(-1).cast(pl.Float64),
        )
        .unique()
    )
    
    # For each censored value, gather the numeric observations from the
    # same source column, keep those on the correct side of the bound, and
    # reduce them to one median per (column, value) pair.
    df_mapping = (
        df_calc
        .join(df_num, on="variable", how="inner")
        .filter(
            ((pl.col("oper") == "<") & (pl.col("num") < pl.col("bound")))
            | ((pl.col("oper") == ">") & (pl.col("num") > pl.col("bound")))
        )
        .group_by("variable", "value")
        .agg(pl.col("num").median())
    )
    
    # Attach the medians, fall back to the plain numeric value for
    # uncensored rows, then pivot back to wide format.  The fallback is
    # cast non-strictly BEFORE filling: a censored value with no
    # qualifying observations (e.g. ">10000") then stays null, as
    # required, instead of making a later strict cast raise on the raw
    # string.
    (
        df_unpivot
        .join(df_mapping, on=["variable","value"], how="left")
        .with_columns(pl.col.num.fill_null(pl.col.value.cast(pl.Float64, strict=False)))
        .pivot("variable", index="index", values=["value","num"])
    )
    
    shape: (6, 5)
    ┌───────┬────────────┬────────────┬──────────┬──────────┐
    │ index ┆ value_CRP1 ┆ value_CRP2 ┆ num_CRP1 ┆ num_CRP2 │
    │ ---   ┆ ---        ┆ ---        ┆ ---      ┆ ---      │
    │ u32   ┆ str        ┆ str        ┆ f64      ┆ f64      │
    ╞═══════╪════════════╪════════════╪══════════╪══════════╡
    │ 0     ┆ <4         ┆ 1          ┆ 2.5      ┆ 1.0      │
    │ 1     ┆ 3          ┆ 2          ┆ 3.0      ┆ 2.0      │
    │ 2     ┆ 2          ┆ null       ┆ 2.0      ┆ null     │
    │ 3     ┆ null       ┆ 3          ┆ null     ┆ 3.0      │
    │ 4     ┆ >5         ┆ 4          ┆ 10.0     ┆ 4.0      │
    │ 5     ┆ 10         ┆ <2         ┆ 10.0     ┆ 1.0      │
    └───────┴────────────┴────────────┴──────────┴──────────┘