python pandas seaborn relplot

Exclude subplots without any data and left-align the rest in relplot


Related to this question: Use relplot to plot a pandas dataframe leading to error

Data for reproducible example is here:

import pandas as pd

data = {'Index': ['TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN', 'TN10p', 'CSU', 'PRCPTOT', 'SDII', 'CWD', 'R99p', 'R99pTOT', 'TX', 'MIN'],
        'Stage': [10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28],
        'Z-Score CEI': [-0.688363146221944, 0.5773502691896258, -0.1132178081286216, -0.4278470185781525, 1.0564189237269357, -0.2085144140570746, -0.2085144140570747, 0.2094308186874662, 0.7196177629619716, 0.0, 0.2085144140570762, -1.3803992008056865, -1.3414801279616884, -0.898669162696764, -0.3015113445777637, -0.2953788838542738, 1.1753566728623484, 0.887285779752818, -0.7071067811865475, 0.2847473987257496, 0.1877402877114761, -0.14246249364941, 0.9686648999069224, -0.3015113445777636, -0.2734952011457535, 0.5888914135578924, -0.4488478006064821, -0.7745966692414834, 0.3052145041378634, 0.8197566686157259, 0.3377616284580471, 1.1832159566199232, -0.3015113445777637, -0.2952684241380082, -0.7971688059921156, 0.4479595231454734, -0.5805577953661853, 0.3015113445777642, -0.610500944190139, -0.7734588159553295, -0.5434722467562666, -0.2085144140570747, -0.2085144140570747, 0.8838570486142397, -0.7976091842744983, 2.213211486674006, 0.3779644730092272, -0.6900911175081499, -0.4856558012299846, -0.6044504143545613, -0.2085144140570746, -0.2085144140570747, 1.6498242899497324, 0.463638205246897, -0.064684622735315, 0.5488212999484522, -0.665392754456709, -1.096398502672124, 0.9387247898517332, -0.2085144140570747, -0.2085144140570748, 1.5486212537866115, 0.6776076459912243, -0.7973761651368712, 0.4773960376293314, 0.2611306759187019, -0.2450438178293888, 0.1097642599896903, -0.2085144140570746, -0.2085144140570747, 1.2468175442040146, 0.4912008775378222, -0.8071397220005339, 0.3015113445777636, -0.4051430868010012, -0.9843673918740764, 0.4231429298696365, -0.2085144140570746, -0.2182178902359924, 1.0617336112420042, 0.4221998839727844, -0.2267786838055363, 0.2847473987257496, 1.2708306299144654, 2.4058495687034616, -0.1042572070285372, 4.79583152331272, 4.79583152331272, -0.1758750648062869, 0.9614146130140746, -0.6493094697110509, 0.2847473987257496, -0.0566333001085325, 0.0970016157961683, -0.3380617018914065, -0.2085144140570746, -0.2132007163556104, 
1.6462867435913509, 0.8920062635166146, -0.649519052838329, 0.2847473987257496, -0.5727902328114448, -0.385256843427376, 0.123403510468459, -0.2085144140570747, -0.2085144140570747, 0.7206954054604126, -0.0169294393471337, -0.1547646465068273, 0.3900382256192578, -0.91200685504817, -0.7643838011372592, -0.8553913029328061, -0.2085144140570746, -0.2132007163556104, 1.999517273479448, 0.2135313581345105, 0.3577708763999664, 0.2085144140570741, -0.5245759407883583, -0.3972170332271401, 0.1363988678940945, -0.2085144140570746, -0.2085144140570747, 2.180043023382912, 0.6949201395674811, -0.0345238339879863, 0.3872983346207417, -1.054383845470446, -0.7524909974608698, -0.79555728417573, -0.2085144140570747, -0.2085144140570747, 2.597515932302782, -0.0173575308522844, -0.7839294959021852, 0.5496481403962044, 0.3346732026206391, -0.1729151200242987, 0.8108848540793832, -0.2085144140570747, -0.2085144140570747, -0.1975075078549267, -0.1333012766349092, -0.7300956427599692, 0.3495310368212778, -0.9383516638143292, 0.3757624051611033, -0.9198662110078, -0.2085144140570747, -0.2085144140570747, 0.1077379509580834, -0.0391099277150297, -0.8006407690254357, 0.5226257719601375, 0.2650955994479978, -0.3323178678594628, 1.348187695720845, -0.2085144140570746, -0.2085144140570748, 0.6009413558916348, 0.455353435995126, -0.5933908290969269, 0.0, 0.1226864783178058, -0.0252747129054563, 0.8212299340934688, -0.2085144140570746, -0.2132007163556105, -0.8954835101738379, -1.1134420487718968],
        'Type': ['Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI', 'Cold', 'Heat', 'Rain', 'Rain', 'Rain', 'Rain', 'Rain', 'Temperature', 'VI']}

# Build the example frame from the dict above; its columns are
# Index, Stage, Z-Score CEI and Type.
df = pd.DataFrame(data)

I want to plot the data; rows should be based on the column Type, cols should be based on the column Index, the x-axis should be Z-Score CEI, and the y-axis should be based on Stage column. Currently, I am using relplot to do this:

# Keep only the Index groups that have at least one non-missing Z-Score CEI.
df = df.groupby('Index').filter(lambda grp: grp['Z-Score CEI'].notna().any())

# Facet on categorical columns, dropping categories that no longer occur.
df["Type"] = df["Type"].astype("category").cat.remove_unused_categories()
df["Index"] = df["Index"].astype("category").cat.remove_unused_categories()

# One facet per (Type, Index) combination: Type on the rows, Index on the columns.
g = sns.relplot(
    data=df,
    x='Z-Score CEI',
    y='Stage',
    col='Index',
    row='Type',
    facet_kws={'sharey': True, 'sharex': True},
    kind='line',
    legend=False,
)

# Hide the axes of every facet that received no rows.
# The loop variable is deliberately not called `data`, so the raw `data` dict
# used to build `df` is not overwritten; the hue index is unused here.
for (row_i, col_j, _), facet_df in g.facet_data():
    if facet_df.empty:
        g.facet_axis(row_i, col_j).set_axis_off()

However, this leads to a plot where the empty plots are distorting the placement of the subplots with data. I want there to be no empty areas.

Current output looks like so: (image omitted)

In the graphic above, I want to remove all the subplots which have no data. This will result in different rows having a different number of subplots, e.g. the 1st row might have 5 subplots and the 2nd row only 4 subplots, etc.

I want each row to only have the same Type, not mix multiple Types.


Solution

  • Here is another solution that is based on @mwaskom's suggestion in the comments. The basic idea is to create an auxiliary column where for each Type, existing Index values are labeled 0,1,2,... which will act as the column index in the FacetGrid. Then after plotting the relplot, remove all Axes without data and fix the title of the ones with data by replacing the column index by the "real" Index value.

    # Label the existing Type-Index pairs: within each Type (index level 0 of
    # the value_counts result) the pairs are numbered 0, 1, 2, ... by cumcount.
    # The numbers are stored as strings because they are compared below with
    # text parsed out of the facet titles.
    col_idx = df.value_counts(['Type', 'Index']).groupby(level=0, observed=False).cumcount().astype(str)
    # Map the labels back to the dataframe as a new 'column_loc' column.
    df1 = df.merge(col_idx.reset_index(name='column_loc'), on=['Type', 'Index'], how='left')
    
    # Plot the relplot, faceting the columns on 'column_loc' instead of 'Index'.
    g = sns.relplot(
        data=df1,             # <--- new dataframe
        x='Z-Score CEI',
        y='Stage',
        col='column_loc',     # <--- column is by the newly created column
        row='Type',
        facet_kws={'sharey': True, 'sharex': True},
        kind='line',
        legend=False,
    )
    for ax in g.axes.flat:
        if not ax.lines:
            g.fig.delaxes(ax) # remove empty subplots (no line was drawn on them)
        else:
            # Fix the title: recover the Type value and the column label from it.
            # Assumes the title has the form "<name> = <value> | <name> = <value>"
            # and that the values themselves contain neither ' = ' nor ' | '.
            typ, loc = (x.split(' = ')[1] for x in ax.get_title().split(' | '))
            # Look up the real Index for this (Type, label) pair; the
            # single-element unpacking raises unless exactly one Index matches.
            idx, = col_idx[col_idx==loc].loc[typ].index
            ax.set_title(f"Type = {typ} | Index = {idx}")
    

    result

    I think matplotlib is very easy to use for this particular task. Because both the Type and Index columns are of Categorical dtype, passing observed=True to pandas groupby simply drops the Index values that don't exist for each Type. Basically, we can use a nested groupby to create sub-dataframes, each of which can be fed into lineplot. However, because we need to draw each lineplot manually, it may be slow (maybe not, since relplot is slow anyway).

    import matplotlib.pyplot as plt

    # One grid row per observed Type; the grid is as wide as the Type with the
    # most distinct Index values.
    gby_obj = df.groupby('Type', observed=True)
    nrows = gby_obj.ngroups
    ncols = gby_obj['Index'].nunique().max()

    # squeeze=False keeps `axs` two-dimensional even when there is only one
    # row or one column, so the axs[i, ...] indexing below always works.
    fig, axs = plt.subplots(nrows, ncols, figsize=(20,20), sharey=True, sharex=True, squeeze=False)
    for i, (typ, g1) in enumerate(gby_obj):
        # Materialise this row's (Index, sub-dataframe) pairs so the number of
        # used axes is known without relying on a leaked loop variable.
        index_groups = list(g1.groupby('Index', observed=True))
        for ax, (idx, g2) in zip(axs[i], index_groups):
            sns.lineplot(data=g2, x='Z-Score CEI', y='Stage', ax=ax)
            ax.set_title(f'Type = {typ} | Index = {idx}')
        # Remove the unused axes at the right end of the row, which leaves the
        # plotted ones left-aligned.
        for unused_ax in axs[i, len(index_groups):]:
            fig.delaxes(unused_ax)
    sns.despine(fig, top=True, right=True)
    fig.tight_layout()