
How to perform row aggregation across the largest x columns in a polars data frame?


I have a data frame with 6 value columns, and for each row I want to sum the 3 largest values. I also want to create an indicator (0/1) matrix that identifies which columns were included in each row's sum.

So the initial data frame may be something like this:

import polars as pl

df = pl.DataFrame({
        'id_col': [0,1,2,3,4],
        'val1': [10,0,0,20,5],
        'val2': [5,1,2,3,10],
        'val3': [8,2,2,2,5],
        'val4': [1,7,7,4,1],
        'val5': [3,0,0,6,0],
        'val6': [2,7,5,5,4]
})

and then the output would look like this:

expected = pl.DataFrame({
        'id_col': [0,1,2,3,4],
        'val1': [1,0,0,1,1],
        'val2': [1,0,1,0,1],
        'val3': [1,1,0,0,1],
        'val4': [0,1,1,0,0],
        'val5': [0,0,0,1,0],
        'val6': [0,1,1,1,0],
        'agg_col': [23,16,14,31,20]
})

Note that there is a tie for third place in the third row (id_col 2, where val2 and val3 are both 2); which of the tied val columns gets credit for the contribution to the sum can be decided arbitrarily.

I have tried concatenating the columns into a list and sorting them but I'm having trouble manipulating the list. I thought maybe I could take the top three from the list and sum them and then perform a row-wise check to see if the original columns were in the list.

val_cols = [f'val{i}' for i in range(1, 7)]
df.with_columns(pl.concat_list(pl.col(val_cols)).list.sort().alias('val_list'))

I have tried making use of top_k_by, cut, and slice but can't quite get it.


Solution

  • Here are the steps:

    1. Unpivot the val columns into long format (one row per id_col and column name).
    2. For each id_col group,
      • sum the 3 largest values using pl.col("value").top_k(3).sum()
      • get a list of the names of the columns holding those values using pl.col("variable").top_k_by("value", k=3)
    3. Construct the flag columns (a row-wise check of whether each column name is in the list of top-3 names) using [pl.lit(col).is_in("variable").cast(pl.Int8).alias(col) for col in val_cols]

    Solution:

    import polars as pl
    import polars.selectors as cs
    
    df = pl.DataFrame(
        {
            "id_col": [0, 1, 2, 3, 4],
            "val1": [10, 0, 0, 20, 5],
            "val2": [5, 1, 2, 3, 10],
            "val3": [8, 2, 2, 2, 5],
            "val4": [1, 7, 7, 4, 1],
            "val5": [3, 0, 0, 6, 0],
            "val6": [2, 7, 5, 5, 4],
        }
    )
    
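    # Resolve the selector to the concrete names of the value columns
    val_cols = cs.expand_selector(df, cs.starts_with("val"))
    
    res = (
        # Long format: one row per (id_col, variable, value)
        df.unpivot(
            cs.starts_with("val"),
            index="id_col",
        )
        .group_by("id_col")
        .agg(
            # Sum of the 3 largest values in each group
            pl.col("value").top_k(3).sum().alias("agg_col"),
            # Names of the columns holding those 3 values (one list per group)
            pl.col("variable").top_k_by("value", k=3),
        )
        .select(
            "id_col",
            # 1 if the column name is in the group's top-3 list, else 0
            *[pl.lit(col).is_in("variable").cast(pl.Int8).alias(col) for col in val_cols],
            "agg_col",
        )
        # group_by does not guarantee row order, so restore it
        .sort("id_col")
    )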
    

    Output:

    
    >>> res 
    
    shape: (5, 8)
    ┌────────┬──────┬──────┬──────┬──────┬──────┬──────┬─────────┐
    │ id_col ┆ val1 ┆ val2 ┆ val3 ┆ val4 ┆ val5 ┆ val6 ┆ agg_col │
    │ ---    ┆ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---     │
    │ i64    ┆ i8   ┆ i8   ┆ i8   ┆ i8   ┆ i8   ┆ i8   ┆ i64     │
    ╞════════╪══════╪══════╪══════╪══════╪══════╪══════╪═════════╡
    │ 0      ┆ 1    ┆ 1    ┆ 1    ┆ 0    ┆ 0    ┆ 0    ┆ 23      │
    │ 1      ┆ 0    ┆ 0    ┆ 1    ┆ 1    ┆ 0    ┆ 1    ┆ 16      │
    │ 2      ┆ 0    ┆ 1    ┆ 0    ┆ 1    ┆ 0    ┆ 1    ┆ 14      │
    │ 3      ┆ 1    ┆ 0    ┆ 0    ┆ 0    ┆ 1    ┆ 1    ┆ 31      │
    │ 4      ┆ 1    ┆ 1    ┆ 1    ┆ 0    ┆ 0    ┆ 0    ┆ 20      │
    └────────┴──────┴──────┴──────┴──────┴──────┴──────┴─────────┘
    
    # Output of the first step
    >>> df.unpivot(cs.starts_with("val"), index="id_col")
    
    shape: (30, 3)
    ┌────────┬──────────┬───────┐
    │ id_col ┆ variable ┆ value │
    │ ---    ┆ ---      ┆ ---   │
    │ i64    ┆ str      ┆ i64   │
    ╞════════╪══════════╪═══════╡
    │ 0      ┆ val1     ┆ 10    │
    │ 1      ┆ val1     ┆ 0     │
    │ 2      ┆ val1     ┆ 0     │
    │ 3      ┆ val1     ┆ 20    │
    │ 4      ┆ val1     ┆ 5     │
    │ …      ┆ …        ┆ …     │
    │ 0      ┆ val6     ┆ 2     │
    │ 1      ┆ val6     ┆ 7     │
    │ 2      ┆ val6     ┆ 5     │
    │ 3      ┆ val6     ┆ 5     │
    │ 4      ┆ val6     ┆ 4     │
    └────────┴──────────┴───────┘