I am using seaborn and t-SNE to visualise class separability/overlap in my dataset, which contains five classes. My plot is thus a 2x2 grid of subplots. I used the following function, which generates the figure below.
def pair_plot_tsne(df):
    """Embed four class subsets of *df* with t-SNE and scatter them on a 2x2 grid.

    Each panel drops one or two of the ``'mode'`` classes before running
    t-SNE, so every panel shows a different subset of the data:

    1. without ``car`` and ``bus``
    2. without ``foot`` and ``bus``
    3. without ``car``
    4. without ``foot``

    A single ``mode -> color`` mapping is built from every mode present in
    *df* and shared by all panels, so a class keeps the same color wherever
    it appears.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'mode'`` column plus every column listed in the
        module-level ``cols`` list (the t-SNE input features).

    Returns
    -------
    tuple
        ``(f, ax1, ax2, ax3, ax4)``: the figure and its four axes in
        row-major subplot order.
    """
    tsne = TSNE(verbose=1, random_state=234)

    # One palette for the whole figure. A per-panel sns.color_palette('hls', k)
    # assigns colors by position, so the same class could change color between
    # panels. Null modes are skipped because seaborn does not treat them as hue
    # levels.
    all_modes = list(df['mode'].dropna().unique())
    palette = dict(zip(all_modes, sns.color_palette('hls', len(all_modes))))

    # Classes excluded from each panel, in subplot order.
    excluded = (
        ('car', 'bus'),
        ('foot', 'bus'),
        ('car',),
        ('foot',),
    )

    # Two rows of plots need more height than the original 16x4 inches.
    f = plt.figure(figsize=(16, 8))
    axes = []
    for position, drop in enumerate(excluded, start=1):
        # .copy() makes the new t-SNE columns land in an independent frame
        # rather than a slice of df (avoids pandas' SettingWithCopyWarning).
        subset = df[~df['mode'].isin(drop)].copy()
        embedding = tsne.fit_transform(subset[cols].values)
        subset['tsne_one'] = embedding[:, 0]
        subset['tsne-two'] = embedding[:, 1]

        # Keep legend entries in the same order in every panel.
        present = set(subset['mode'])
        hue_order = [mode for mode in all_modes if mode in present]

        ax = plt.subplot(2, 2, position)
        sns.scatterplot(
            x='tsne_one', y='tsne-two', hue='mode', data=subset,
            palette=palette, hue_order=hue_order,
            legend='full', alpha=0.7, ax=ax,
        )
        axes.append(ax)

    return (f, *axes)
Since I'm plotting a subset of the dataset in each subplot, I would like the color of each class to be consistent in whichever plot it appears. For example, a blue color for the `car` mode in whichever subplot it appears, a black color for the `bus` mode in whichever subplot it appears, etc.
As it is now, `foot` is red in `subplot(2, 2, 1)`, and `car` is also red in `subplot(2, 2, 2)`, although the rest are consistent.
For this use case, seaborn accepts a dictionary as the `palette` argument. The dictionary assigns a color to each hue value, so a given value gets the same color in every plot that uses that dictionary.
Here is an example of how such a dictionary could be created for your data:
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# Reproducible random stand-ins for the four t-SNE result frames.
np.random.seed(2023)


def _random_frame(categories, size=10):
    """Return *size* random 2-D points, each labelled with a mode drawn from *categories*."""
    # Draw order (x, then y, then labels) matches a dict literal built in place.
    return pd.DataFrame({
        'tsne_one': np.random.randn(size),
        'tsne-two': np.random.randn(size),
        'mode': np.random.choice(categories, size),
    })


df1, df2, df3, df4 = (
    _random_frame(['foot', 'metro', 'bike']),
    _random_frame(['car', 'metro', 'bike']),
    _random_frame(['foot', 'bus', 'metro', 'bike']),
    _random_frame(['car', 'bus', 'metro', 'bike']),
)
frames = (df1, df2, df3, df4)

# Every mode seen in any frame, in order of first appearance.
all_modes = pd.concat([frame['mode'] for frame in frames], ignore_index=True)
modes = all_modes.unique()
colors = sns.color_palette('hls', len(modes))
# One color per mode. Passing this dict as `palette` keeps each mode's color
# identical across all subplots, whatever subset of modes a subplot contains.
palette = {mode: color for mode, color in zip(modes, colors)}

# A 2x2 grid; pair each frame with one axes in row-major order.
fig, axs = plt.subplots(2, 2, figsize=(12, 6), tight_layout=True)
for frame, axis in zip(frames, axs.flat):
    sns.scatterplot(x='tsne_one', y='tsne-two', hue='mode', data=frame,
                    palette=palette, legend='full', alpha=0.7, ax=axis)
plt.show()
You may also want a single figure-level legend:
from matplotlib.lines import Line2D
# Same 2x2 grid as above, but with one shared legend for the whole figure.
fig, axs = plt.subplots(2, 2, figsize=(12, 6), tight_layout=True)
for df, ax in zip((df1, df2, df3, df4), axs.flatten()):
    # legend=False: no per-subplot legend is drawn. This avoids calling
    # ax.legend() just to .remove() it, which builds a throwaway legend and
    # can warn when the axes has no labelled artists.
    sns.scatterplot(x='tsne_one', y='tsne-two', hue='mode', data=df,
                    palette=palette, legend=False, alpha=0.7, ax=ax)

# One proxy marker per mode, built from the same dict the subplots used, so
# legend labels and colors cannot get out of step.
handles = [
    Line2D([0], [0], marker='o', linestyle='', markerfacecolor=color,
           markeredgecolor='none', alpha=0.7)
    for color in palette.values()
]
# A single figure-level legend to the right of the grid. It is anchored
# outside the figure area, so it may be clipped in an interactive window.
# Saving with fig.savefig(..., bbox_inches='tight') should keep it in the image.
fig.legend(handles, list(palette), title='Mode', bbox_to_anchor=(1, 0.5),
           loc='center left', frameon=False)
plt.show()