python pandas multi-index

Why does groupby with dropna=False prevent a subsequent MultiIndex.dropna() from working?


My understanding is that MultiIndex.dropna() removes index entries in which at least one level is NaN, unconditionally. However, it seems that once a groupby has been run with dropna=False, MultiIndex.dropna() no longer has any effect.

(I'm aware that groupby would drop the NaN groups without the dropna parameter, but I'm looking for a solution that works when the parameter has already been used at some earlier point.)


import pandas as pd
import numpy as np

d = {(8.0, 8.0): {'A': -1.10, 'B': -1.0},
     (7.0, 8.0): {'A': -0.10, 'B': 0.1},
     (5.0, 8.0): {'A': 1.15, 'B': -1.2},
     (7.0, 7.0): {'A': 1.10, 'B': 1.6},
     (7.0, np.nan): {'A': 0.70, 'B': -0.7},
     (8.0, np.nan): {'A': -1.00, 'B': 0.9},
     (np.nan, 5.0): {'A': -2.20, 'B': 1.1}}

# This works as expected
index = pd.MultiIndex.from_tuples(d.keys(), names=['L1', 'L2'])
df = pd.DataFrame(d.values(), index=index)
print(df.index.dropna())

# This doesn't work as expected
df = df.groupby(['L1', 'L2'], dropna=False).mean()
print(df.index.dropna())

MultiIndex([(8.0, 8.0),
            (7.0, 8.0),
            (5.0, 8.0),
            (7.0, 7.0)],
           names=['L1', 'L2'])

MultiIndex([(5.0, 8.0),
            (7.0, 7.0),
            (7.0, 8.0),
            (7.0, nan),
            (8.0, 8.0),
            (8.0, nan),
            (nan, 5.0)],
           names=['L1', 'L2'])

Solution

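  • The source of pd.MultiIndex.dropna() shows that it works on the index's codes: a MultiIndex stores, for each level, an array of unique values plus an integer code per entry that points into that array.

    dropna() treats only code -1 as NaN. After .groupby(..., dropna=False) that apparently no longer holds: NaN seems to end up as an ordinary level value with a non-negative code, so dropna() does not see it (a bug?).

    You can avoid this issue by reconstructing the index and then dropping the NaN values, e.g.: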

    df = df.groupby(["L1", "L2"], dropna=False).mean()
    
    # reconstruct the index (this will assign code -1 to NaN):
    df.index = pd.MultiIndex.from_tuples(df.index.to_list(), names=df.index.names)
    
    print(df.index.dropna())
    

    Prints:

    MultiIndex([(5.0, 8.0),
                (7.0, 7.0),
                (7.0, 8.0),
                (8.0, 8.0)],
               names=['L1', 'L2'])
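
To see the codes described above, and as a possible alternative to rebuilding the index, here is a minimal sketch. It assumes a smaller frame than the one in the question; the names has_nan and cleaned are made up for illustration. Instead of relying on the -1 code, it tests the actual level values, so it should not depend on how a given pandas version stores NaN after groupby.

```python
import numpy as np
import pandas as pd

# Smaller, hypothetical version of the frame from the question.
index = pd.MultiIndex.from_tuples(
    [(8.0, 8.0), (7.0, np.nan), (np.nan, 5.0), (7.0, 7.0)],
    names=["L1", "L2"],
)
df = pd.DataFrame({"A": [-1.1, 0.7, -2.2, 1.1]}, index=index)

# Built with from_tuples: NaN entries are stored as code -1.
print(index.codes)

grouped = df.groupby(["L1", "L2"], dropna=False).mean()

# Inspect how NaN is stored after groupby(dropna=False);
# the exact levels/codes may differ between pandas versions.
print(grouped.index.levels)
print(grouped.index.codes)

# Build a boolean mask from the actual index values rather than the codes.
has_nan = grouped.index.to_frame(index=False).isna().any(axis=1).to_numpy()

# Keep only the rows whose index has no NaN in any level.
cleaned = grouped.loc[~has_nan]
print(cleaned.index.tolist())
```

This filters the rows of the frame directly, so the index and the data stay aligned, whereas index.dropna() on its own only returns a new index.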