
Replace null values in Polars with values from a dictionary (using another column as key)


I have this code:

import polars as pl

def get_month(item_id: int):
    # In practice, fetch month from some DB ...
    return f'2024-{item_id:02.0f}'

df = pl.DataFrame({
    'item_id': [1, 2, 3, 4],
    'month': [None, '2023-07', None, '2023-08']
})

dict_months = {item_id: get_month(item_id) for item_id in df.filter(pl.col('month').is_null())['item_id']}

df.with_columns(pl.when(pl.col('month').is_null())
                .then(pl.col('item_id').map_elements(lambda id: dict_months[id], return_dtype=pl.String).alias('month'))
                .otherwise(pl.col('month')))

Basically, I want to replace all null entries in the month column with values from dict_months, using item_id as the key. By construction, the dict contains a key for every id whose month is missing, but not for the other ids.

When I run the above, I get the error PanicException: python function failed KeyError: 2. This seems to imply that Polars is trying to look up a value for id 2, which it shouldn't, since id 2 already has a month.

How can this be fixed?


Solution

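  • The KeyError happens because when/then/otherwise evaluates the then() expression for every row and only afterwards selects, row by row, between the then and otherwise values. So the lambda is also called for id 2, which is not in dict_months. Here's one way to avoid it, using replace_strict with the default argument so that unmapped ids become null instead of raising. I also replaced the when/then/otherwise with coalesce, which takes the first non-null value: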

    print(
        df.with_columns(
            pl.coalesce(
                "month", pl.col("item_id").replace_strict(dict_months, default=None)
            )
        )
    )
    

    Output:

    shape: (4, 2)
    ┌─────────┬─────────┐
    │ item_id ┆ month   │
    │ ---     ┆ ---     │
    │ i64     ┆ str     │
    ╞═════════╪═════════╡
    │ 1       ┆ 2024-01 │
    │ 2       ┆ 2023-07 │
    │ 3       ┆ 2024-03 │
    │ 4       ┆ 2023-08 │
    └─────────┴─────────┘
    

    Full code:

    import polars as pl
    
    
    def get_month(item_id: int):
        # In practice, fetch month from some DB ...
        return f"2024-{item_id:02.0f}"
    
    
    df = pl.DataFrame(
        {"item_id": [1, 2, 3, 4], "month": [None, "2023-07", None, "2023-08"]}
    )
    
    dict_months = {
        item_id: get_month(item_id)
        for item_id in df.filter(pl.col("month").is_null())["item_id"]
    }
    
    print(
        df.with_columns(
            pl.coalesce(
                "month", pl.col("item_id").replace_strict(dict_months, default=None)
            )
        )
    )