python matplotlib seaborn color-palette pairplot

Set same color palette for multiple plots from several dataframes


I am using seaborn and t-SNE to visualise class separability/overlap in my dataset, which contains five classes. My plot is thus a 2x2 grid of subplots. I used the following function, which generates the figure below.

def pair_plot_tsne(df):
    """Draw t-SNE scatter plots for four subsets of ``df`` on a 2x2 grid.

    Each subset is obtained by dropping one or two values of the ``'mode'``
    column. It is embedded in two dimensions with t-SNE and drawn on its own
    subplot, coloured by ``'mode'``.

    Note: every subplot builds its own ``'hls'`` palette with a hard-coded
    number of colours, so a given mode is not guaranteed to get the same
    colour in every subplot.

    Parameters
    ----------
    df : DataFrame
        Must contain a ``'mode'`` column and the feature columns named in
        ``cols`` (which is defined outside this function).

    Returns
    -------
    tuple
        ``(figure, ax1, ax2, ax3, ax4)`` with the axes in subplot order 1-4.
    """
    tsne = TSNE(verbose=1, random_state=234)
    mode = df['mode']

    # One entry per subplot, in grid order: (row mask, number of palette
    # colours). The colour counts are the class counts the author expects to
    # remain after filtering (3 for the first two panels, 4 for the last two).
    panels = [
        ((mode != 'car') & (mode != 'bus'), 3),
        ((mode != 'foot') & (mode != 'bus'), 3),
        (mode != 'car', 4),
        (mode != 'foot', 4),
    ]

    embedded = []
    for mask, _ in panels:
        # Work on an explicit copy: adding columns to a bare selection of
        # ``df`` triggers pandas' SettingWithCopyWarning.
        subset = df[mask].copy()
        coords = tsne.fit_transform(subset[cols].values)  # cols - df's columns list
        subset['tsne_one'] = coords[:, 0]
        subset['tsne-two'] = coords[:, 1]
        embedded.append(subset)

    # create figure
    f = plt.figure(figsize=(16, 4))

    axes = []
    for position, (subset, (_, n_colors)) in enumerate(zip(embedded, panels), start=1):
        ax = plt.subplot(2, 2, position)
        sns.scatterplot(
            x='tsne_one', y='tsne-two', hue='mode', data=subset,
            palette=sns.color_palette('hls', n_colors),
            legend='full', alpha=0.7, ax=ax)
        axes.append(ax)

    return (f, *axes)

enter image description here

Since I'm plotting a subset of the dataset in each subplot, I would like the color of each class to be consistent in whichever plot it appears. For example, a blue color for the car mode in whichever subplot it appears, a black color for the bus mode in whichever plot it appears, etc.

As it is now, foot is red in subplot(2, 2, 1), and car is also red in subplot(2, 2, 2), although the rest are consistent.


Solution

  • For this use case, seaborn allows a dictionary as palette. The dictionary will assign a color to each hue value.

    Here is an example of how such a dictionary could be created for your data:

    from matplotlib import pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    # reproducible toy data: one small dataframe per subplot
    np.random.seed(2023)

    def make_frame(categories, n_points=10):
        """Return a frame of random coordinates with a random 'mode' label per row."""
        return pd.DataFrame({'tsne_one': np.random.randn(n_points),
                             'tsne-two': np.random.randn(n_points),
                             'mode': np.random.choice(categories, n_points)})

    df1 = make_frame(['foot', 'metro', 'bike'])
    df2 = make_frame(['car', 'metro', 'bike'])
    df3 = make_frame(['foot', 'bus', 'metro', 'bike'])
    df4 = make_frame(['car', 'bus', 'metro', 'bike'])
    frames = (df1, df2, df3, df4)

    # collect every distinct value of the 'mode' column across all dataframes
    all_modes = pd.concat([frame['mode'] for frame in frames], ignore_index=True)
    modes = all_modes.unique()

    # one color per mode, taken from a palette sized to the number of modes
    colors = sns.color_palette('hls', len(modes))

    # map each mode to its color; passing this dict as the palette is what
    # keeps the colors identical in every subplot
    palette = {mode: color for mode, color in zip(modes, colors)}

    # create the figure and its 2x2 grid of axes
    fig, axs = plt.subplots(2, 2, figsize=(12,6), tight_layout=True)

    # draw each dataframe on its own axes using the shared palette
    for ax, frame in zip(axs.flatten(), frames):
        sns.scatterplot(data=frame, x='tsne_one', y='tsne-two', hue='mode',
                        palette=palette, legend='full', alpha=0.7, ax=ax)
    plt.show()
    

    enter image description here


    You may also want a single figure-level legend:

    from matplotlib.lines import Line2D
    
    fig, axs = plt.subplots(2, 2, figsize=(12,6), tight_layout=True)

    for df, ax in zip((df1, df2, df3, df4), axs.flatten()):
        # legend=False: do not draw a per-subplot legend at all, rather than
        # drawing one and removing it afterwards
        sns.scatterplot(x='tsne_one', y='tsne-two', hue='mode', data=df, palette=palette, legend=False, alpha=0.7, ax=ax)

    # create the custom legend handles: one marker per mode, built from the
    # palette dict so that each label is paired with its own color
    handles = [Line2D([0], [0], marker='o', markerfacecolor=color, linestyle='', markeredgecolor='none')
               for color in palette.values()]

    # create a single figure legend, anchored just outside the right edge of the figure
    fig.legend(handles, list(palette), title='Mode', bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)
    plt.show()
    

    enter image description here