pandas group-by ranking

How to efficiently select the top N columns by grouping for each row of a pandas DataFrame?


Let's say that I have a pandas DataFrame representing the coolness score for each "contestant" of a hypothetical contest by date:

import numpy as np
import pandas as pd

rng = np.random.default_rng()
dates = pd.date_range('2024-08-01', '2024-08-07')
contestants = ['Alligator', 'Beryl', 'Chupacabra', 'Dandelion', 'Eggplant', 'Feldspar']
coolness_score = pd.DataFrame(rng.random((len(dates), len(contestants))), index=dates, columns=contestants)
            Alligator     Beryl  Chupacabra  Dandelion  Eggplant  Feldspar
2024-08-01   0.213901  0.952705    0.801651   0.511080  0.662109  0.486296
2024-08-02   0.495700  0.660502    0.379900   0.778438  0.038616  0.214174
2024-08-03   0.639337  0.036226    0.811501   0.281915  0.101850  0.437146
2024-08-04   0.238590  0.686965    0.357087   0.810922  0.907803  0.370247
2024-08-05   0.712564  0.800191    0.040616   0.503644  0.354333  0.742269
2024-08-06   0.916343  0.299557    0.405399   0.851161  0.336570  0.246618
2024-08-07   0.047052  0.645420    0.823397   0.198483  0.368888  0.168188

In addition, each contestant is mapped to a particular category, and limits are imposed on each category:

category_mapping = {
    'Alligator': 'Animal',
    'Beryl': 'Mineral',
    'Chupacabra': 'Animal',
    'Dandelion': 'Vegetable',
    'Eggplant': 'Vegetable',
    'Feldspar': 'Mineral'
}

category_limits = {
    'Animal': 1,
    'Vegetable': 2,
    'Mineral': 1
}

How do I select the contestants with the top scores in each category on each date? Specifically, I am considering three scenarios:

  1. The best single score from each category

  2. The best N scores from each category, where N is consistent across all categories

  3. The best scores from each category with limits defined by category_limits

Or, even better, how do I set the scores of the losers to zero?

Scenarios 1 and 2 are clearly special cases of Scenario 3, but I figured there might be some built-in functions that yield efficiency gains here. If left to my own devices, I'd probably iterate by date, but that seems like the slowest possible approach. Thanks for your help.

Edit 1: In Scenario 3, I mean to apply the category-specific limits on each date. Using the example posted above, that would be the top Animal, the top two Vegetables, and the top Mineral on each date.

Edit 2: Amazing answers so far. I should add that I'm also looking for speed, and my actual application is on the order of 250 rows x 15000 columns, with about 175 different categories, and will be run many, many times as part of a Monte Carlo simulation. I'll be testing each solution when I have a clearer mind, but I welcome any discussion about performance at the intended scale. Thanks!


Solution

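  • You can map each column to its category and limit, rank the scores within each category per date with groupby.rank, then zero out the losers with where (method='dense' gives tied scores the same rank, so a tie can keep more than the limit; method='first' breaks ties by column order):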

    # names to categories
    cat = coolness_score.columns.map(category_mapping)
    # categories to limits
    limit = cat.map(category_limits)
    
    # rank per category
    rank = coolness_score.T.groupby(cat).rank(method='dense', ascending=False).T
    # identify top N per category per row
    mask = rank.le(limit)
    
    # mask losers
    out = coolness_score.where(mask, 0)
    

    Output (the values differ from the table in the question because the generator is unseeded):

                Alligator     Beryl  Chupacabra  Dandelion  Eggplant  Feldspar
    2024-08-01   0.000000  0.878593    0.957980   0.887114  0.266656  0.000000
    2024-08-02   0.000000  0.660319    0.737451   0.921197  0.446438  0.000000
    2024-08-03   0.000000  0.000000    0.765396   0.334504  0.250021  0.736392
    2024-08-04   0.000000  0.000000    0.990308   0.357501  0.124491  0.941783
    2024-08-05   0.327078  0.000000    0.000000   0.309475  0.538202  0.952041
    2024-08-06   0.533576  0.000000    0.000000   0.935781  0.587427  0.690166
    2024-08-07   0.767592  0.000000    0.000000   0.222281  0.879662  0.821808
    

    Intermediates:

    # cat
    Index(['Animal', 'Mineral', 'Animal', 'Vegetable', 'Vegetable', 'Mineral'], dtype='object')
    
    # limit
    Index([1, 1, 1, 2, 2, 1], dtype='int64')
    
    # rank
                Alligator  Beryl  Chupacabra  Dandelion  Eggplant  Feldspar
    2024-08-01        2.0    1.0         1.0        1.0       2.0       2.0
    2024-08-02        2.0    1.0         1.0        1.0       2.0       2.0
    2024-08-03        2.0    2.0         1.0        1.0       2.0       1.0
    2024-08-04        2.0    2.0         1.0        1.0       2.0       1.0
    2024-08-05        1.0    2.0         2.0        2.0       1.0       1.0
    2024-08-06        1.0    2.0         2.0        1.0       2.0       1.0
    2024-08-07        1.0    2.0         2.0        2.0       1.0       1.0
    
    # mask
                Alligator  Beryl  Chupacabra  Dandelion  Eggplant  Feldspar
    2024-08-01      False   True        True       True      True     False
    2024-08-02      False   True        True       True      True     False
    2024-08-03      False  False        True       True      True      True
    2024-08-04      False  False        True       True      True      True
    2024-08-05       True  False       False       True      True      True
    2024-08-06       True  False       False       True      True      True
    2024-08-07       True  False       False       True      True      True
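
Since the question also asks about speed at roughly 250 rows x 15000 columns, here is a minimal NumPy sketch of the same idea, offered as something to benchmark rather than a guaranteed improvement. It sorts each row by category and then by descending score, keeps the first few positions of each category block, and scatters that mask back to the original column order. The helper name zero_losers_numpy, the fixed seed, and the variable n are assumptions introduced for illustration; the sketch also assumes every column has a category and that there are no NaN scores. Scenarios 1 and 2 are expressed by passing the same limit for every category (with the pandas version above, they amount to rank.le(1) and rank.le(n)).

```python
import numpy as np
import pandas as pd


def zero_losers_numpy(scores, category_mapping, limits):
    """Keep the top limits[category] scores per category in each row, zero the rest.

    Hypothetical helper, not part of pandas. Assumes every column is mapped
    to a category and that the scores contain no NaN.
    """
    values = scores.to_numpy()
    n_rows, n_cols = values.shape

    # Integer code per column for its category, plus the distinct categories
    codes, categories = pd.factorize(scores.columns.map(category_mapping))
    limit_per_code = np.array([limits[c] for c in categories])

    # Per row, order the columns by category first, then by descending score
    # (np.lexsort treats the LAST key as the primary one)
    order = np.lexsort((-values, np.tile(codes, (n_rows, 1))))

    # After sorting, every row has the same category layout
    sorted_codes = np.sort(codes)
    block_start = np.searchsorted(sorted_codes, sorted_codes, side='left')
    pos_in_block = np.arange(n_cols) - block_start  # 0-based rank within category
    keep_sorted = pos_in_block < limit_per_code[sorted_codes]

    # Scatter the keep flags back to the original column positions
    mask = np.zeros(values.shape, dtype=bool)
    np.put_along_axis(mask, order, keep_sorted, axis=1)

    return pd.DataFrame(np.where(mask, values, 0.0),
                        index=scores.index, columns=scores.columns)


# Same setup as the question; the seed is an assumption added for reproducibility
rng = np.random.default_rng(0)
dates = pd.date_range('2024-08-01', '2024-08-07')
contestants = ['Alligator', 'Beryl', 'Chupacabra', 'Dandelion', 'Eggplant', 'Feldspar']
coolness_score = pd.DataFrame(rng.random((len(dates), len(contestants))),
                              index=dates, columns=contestants)
category_mapping = {
    'Alligator': 'Animal', 'Beryl': 'Mineral', 'Chupacabra': 'Animal',
    'Dandelion': 'Vegetable', 'Eggplant': 'Vegetable', 'Feldspar': 'Mineral',
}
category_limits = {'Animal': 1, 'Vegetable': 2, 'Mineral': 1}

# Scenario 3: category-specific limits
scenario_3 = zero_losers_numpy(coolness_score, category_mapping, category_limits)

# Scenario 1: best single score per category
scenario_1 = zero_losers_numpy(coolness_score, category_mapping,
                               dict.fromkeys(category_limits, 1))

# Scenario 2: best n per category, with the same (hypothetical) n everywhere
n = 2
scenario_2 = zero_losers_numpy(coolness_score, category_mapping,
                               dict.fromkeys(category_limits, n))
```

Because np.lexsort is stable, tied scores are broken by column order, so exactly the stated number of columns is kept per category (the same behavior as rank(method='first')). Whether this beats the groupby.rank version at the intended scale depends on the data and would need to be timed, for example with timeit inside the Monte Carlo loop.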