Tags: window-functions, lazy-evaluation, python-polars

Aggregation in Polars window functions: how to select the top value based on an aggregation of another column


I have a large ocean-freight dataset with columns for the bill of lading (bol_id), voyage (voyage_id), carrier (carrier_scac), and container volume in TEUs (teus), similar to this:

import polars as pl

lf = pl.LazyFrame({
    'bol_id':(1,2,3,4,5,6,7,8,9),
    'voyage_id':(1,1,1,2,2,2,3,3,3),
    'carrier_scac':('mscu', 'mscu', 'hpld', 'hpld', 'hpld', 'hpld', 'ever', 'mscu', 'ever'),
    'teus':(20, 40, 5, 10, 25, 20, 5, 45, 5)
})
print(lf.collect())
┌────────┬───────────┬──────────────┬──────┐
│ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus │
│ ---    ┆ ---       ┆ ---          ┆ ---  │
│ i64    ┆ i64       ┆ str          ┆ i64  │
╞════════╪═══════════╪══════════════╪══════╡
│ 1      ┆ 1         ┆ mscu         ┆ 20   │
│ 2      ┆ 1         ┆ mscu         ┆ 40   │
│ 3      ┆ 1         ┆ hpld         ┆ 5    │
│ 4      ┆ 2         ┆ hpld         ┆ 10   │
│ 5      ┆ 2         ┆ hpld         ┆ 25   │
│ 6      ┆ 2         ┆ hpld         ┆ 20   │
│ 7      ┆ 3         ┆ ever         ┆ 5    │
│ 8      ┆ 3         ┆ mscu         ┆ 45   │
│ 9      ┆ 3         ┆ ever         ┆ 5    │
└────────┴───────────┴──────────────┴──────┘

For each voyage, I want the carrier with the highest total teus. I can do this with a group_by followed by a join, but I'd like to use a window function instead and can't work out the syntax/logic in Polars (0.20).

Current working function:

def add_primary_carrier(lf):
    lf2 = (
        lf
        # select relevant cols
        .select('voyage_id', 'carrier_scac', 'teus')
        # ignore bols with missing data
        .drop_nulls()
        # sum up TEUs by voyage and carrier
        .group_by('voyage_id', 'carrier_scac')
        .agg(pl.col('teus').sum().alias('sum_teus'))
        # choose the carrier with the most TEUs on each voyage
        .sort('sum_teus', descending=True)
        .group_by('voyage_id')
        .agg(pl.col('carrier_scac').first().alias('primary_scac'))
    )
    # add primary scac column to main lf
    return lf.join(lf2, how='left', on='voyage_id')

But it seems a window function would be a lot cleaner (and perhaps less resource-intensive). Something like:

def add_primary_carrier_window(lf):
    lf = (
        lf.with_columns(
            pl.col('carrier_scac')
            .sort_by(pl.col('teus').sum().over('carrier_scac'), descending=True)
            .drop_nulls().first()
            .over('voyage_id')
            .alias('primary_scac')
        )
    )
    return lf

But that function raises an InvalidOperationError: "window expression not allowed in aggregation".

Thanks in advance for the help!

Expected output:

┌────────┬───────────┬──────────────┬──────┬──────────────┬──────────────┐
│ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus ┆ primary_scac ┆ shared_cargo │
│ ---    ┆ ---       ┆ ---          ┆ ---  ┆ ---          ┆ ---          │
│ i64    ┆ i64       ┆ str          ┆ i64  ┆ str          ┆ bool         │
╞════════╪═══════════╪══════════════╪══════╪══════════════╪══════════════╡
│ 1      ┆ 1         ┆ mscu         ┆ 20   ┆ mscu         ┆ false        │
│ 2      ┆ 1         ┆ mscu         ┆ 40   ┆ mscu         ┆ false        │
│ 3      ┆ 1         ┆ hpld         ┆ 5    ┆ mscu         ┆ true         │
│ 4      ┆ 2         ┆ hpld         ┆ 10   ┆ hpld         ┆ false        │
│ 5      ┆ 2         ┆ hpld         ┆ 25   ┆ hpld         ┆ false        │
│ 6      ┆ 2         ┆ hpld         ┆ 20   ┆ hpld         ┆ false        │
│ 7      ┆ 3         ┆ ever         ┆ 5    ┆ mscu         ┆ true         │
│ 8      ┆ 3         ┆ mscu         ┆ 45   ┆ mscu         ┆ false        │
│ 9      ┆ 3         ┆ ever         ┆ 5    ┆ mscu         ┆ true         │
└────────┴───────────┴──────────────┴──────┴──────────────┴──────────────┘

Solution

  • There are a few issues on the Polars tracker about this, e.g. https://github.com/pola-rs/polars/issues/14361

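    You have to materialize each .over "aggregation" as its own column in a separate .with_columns call, because window expressions cannot be nested (and a column created in a .with_columns call cannot be referenced by name within that same call). Note also that the per-carrier sum has to be partitioned by voyage too, i.e. .over('voyage_id', 'carrier_scac'); the .over('carrier_scac') in the question would total each carrier's TEUs across all voyages.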

    (lf.with_columns(
        pl.col('teus')
          .sum()
          .over('voyage_id', 'carrier_scac')
          .alias('sum_teus')
       )
       .with_columns(
          pl.col('carrier_scac')
            .sort_by('sum_teus', descending=True)
            .first()
            .over('voyage_id')
            .alias('primary_scac')
       )
       .collect()
    )
    
    shape: (9, 6)
    ┌────────┬───────────┬──────────────┬──────┬──────────┬──────────────┐
    │ bol_id ┆ voyage_id ┆ carrier_scac ┆ teus ┆ sum_teus ┆ primary_scac │
    │ ---    ┆ ---       ┆ ---          ┆ ---  ┆ ---      ┆ ---          │
    │ i64    ┆ i64       ┆ str          ┆ i64  ┆ i64      ┆ str          │
    ╞════════╪═══════════╪══════════════╪══════╪══════════╪══════════════╡
    │ 1      ┆ 1         ┆ mscu         ┆ 20   ┆ 60       ┆ mscu         │
    │ 2      ┆ 1         ┆ mscu         ┆ 40   ┆ 60       ┆ mscu         │
    │ 3      ┆ 1         ┆ hpld         ┆ 5    ┆ 5        ┆ mscu         │
    │ 4      ┆ 2         ┆ hpld         ┆ 10   ┆ 55       ┆ hpld         │
    │ 5      ┆ 2         ┆ hpld         ┆ 25   ┆ 55       ┆ hpld         │
    │ 6      ┆ 2         ┆ hpld         ┆ 20   ┆ 55       ┆ hpld         │
    │ 7      ┆ 3         ┆ ever         ┆ 5    ┆ 10       ┆ mscu         │
    │ 8      ┆ 3         ┆ mscu         ┆ 45   ┆ 45       ┆ mscu         │
    │ 9      ┆ 3         ┆ ever         ┆ 5    ┆ 10       ┆ mscu         │
    └────────┴───────────┴──────────────┴──────┴──────────┴──────────────┘