
Get all the rows with max value in a pandas dataframe column in Python


I have a pandas DataFrame with the following columns:

col1  col2  col3   col4
A101  3     LLT    10028980
A101  7     LLT    10028980
A101  7     PT     10028980
A102  5     LLT    10028981
A102  3     PT     10028981 
A103  2     PT     10028982
A103  4     LLT    10028982

I would like to extract all rows where col2 equals the maximum for each value of col1. The expected output is:

col1  col2  col3   col4
A101  7     LLT    10028980
A101  7     PT     10028980
A102  5     LLT    10028981
A103  4     LLT    10028982

I am using the following lines of code, but they drop rows when several rows in a group share the max value (row 1 is excluded).

m = df.notnull().all(axis=1)

df = df.loc[m].groupby('col1').max().reset_index()

I am getting this output:

col1  col2  col3   col4
A101  7     PT     10028980
A102  5     LLT    10028981
A103  4     LLT    10028982

Solution

  • You can broadcast the per-group maximum back onto the original rows with transform, compare col2 against it to build a boolean mask, and index with that mask:

    >>> df.loc[df["col2"].eq(df.groupby("col1")["col2"].transform("max"))]
    
       col1  col2 col3      col4
    1  A101     7  LLT  10028980
    2  A101     7   PT  10028980
    3  A102     5  LLT  10028981
    6  A103     4  LLT  10028982
    

    Here's what .agg would do (one value per group, indexed by col1):

    >>> df.groupby("col1")["col2"].agg("max")
    col1
    A101    7
    A102    5
    A103    4
    Name: col2, dtype: int64
    

    and what .transform does:

    >>> df.groupby("col1")["col2"].transform("max")
    
    0    7
    1    7
    2    7
    3    5
    4    5
    5    4
    6    4
    Name: col2, dtype: int64
    

    The transformed result has the same index as df, which allows for an element-wise comparison with the col2 column.
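
For anyone who wants to run the approach end to end, here is a minimal self-contained sketch. The frame is rebuilt by hand from the table in the question, and the names `group_max` and `result` are only illustrative, not part of the original code. The null filter from the question is kept as the first step.

```python
import pandas as pd

# Sample data copied from the question's table.
df = pd.DataFrame({
    "col1": ["A101", "A101", "A101", "A102", "A102", "A103", "A103"],
    "col2": [3, 7, 7, 5, 3, 2, 4],
    "col3": ["LLT", "LLT", "PT", "LLT", "PT", "PT", "LLT"],
    "col4": [10028980, 10028980, 10028980, 10028981,
             10028981, 10028982, 10028982],
})

# Keep only rows without nulls, as in the question.
df = df.loc[df.notnull().all(axis=1)]

# Per-group maximum of col2, aligned row by row with df.
group_max = df.groupby("col1")["col2"].transform("max")

# Keep every row whose col2 equals its group's maximum (ties included).
result = df.loc[df["col2"].eq(group_max)]
print(result)
```

Because the mask only selects rows from the original frame, col3 and col4 stay attached to the row they came from. With groupby(...).max(), by contrast, each column is reduced separately, which is why the two tied A101 rows collapsed into one.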