I am trying to plot two 2D distributions together with their marginal distributions on the top and side of the figure, like so:
Now I want to combine the above figure with the following figure, such that they appear side by side:
However, when doing so, the marginal distributions aren't plotted. Can anyone help?
The code for plotting the above figure is given here:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import multivariate_normal
import ot
import ot.plot
# Define the mean and covariance for two different multivariate normal distributions
mean1 = [0, 0]
cov1 = [[1, 0.5], [0.5, 1]]
mean2 = [3, 3]
cov2 = [[1, -0.5], [-0.5, 1]]
n = 100  # number of samples drawn from each distribution

# Generate random samples from the distributions (seeded for reproducibility)
np.random.seed(0)
samples1 = np.random.multivariate_normal(mean1, cov1, size=n)
samples2 = np.random.multivariate_normal(mean2, cov2, size=n)

# Stack both sample sets into one long-form frame for seaborn's `hue`:
# rows [0, n) come from samples1 ("Target"), rows [n, 2n) from samples2 ("Source").
df1 = pd.DataFrame(np.concatenate([samples1, samples2]), columns=['X', 'Y'])
df1['Distribution'] = 'Target'
# Assign with a single .loc call. The chained form
# `df1['Distribution'].iloc[n:] = ...` writes through an intermediate object:
# it warns (SettingWithCopy / ChainedAssignment) and is a silent no-op under
# pandas Copy-on-Write. df1 has the default RangeIndex, so the label slice
# `n:` selects exactly the last n rows.
df1.loc[n:, 'Distribution'] = 'Source'

# Create a custom palette with blue and red
custom_palette = {'Target': 'blue', 'Source': 'red'}
# Plotting side by side
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: bivariate KDE of both distributions.
# NOTE: `kind` and `space` are jointplot options, not kdeplot options, so they
# are not passed here; kdeplot only draws on the single Axes given as `ax`.
sns.kdeplot(data=df1, x="X", y="Y", hue="Distribution", fill=True,
            palette=custom_palette, ax=axs[0])
sns.move_legend(axs[0], "lower right")

# Right panel: optimal transport matching between the samples
a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform weights on the samples
M = ot.dist(samples2, samples1, metric='euclidean')
G0 = ot.emd(a, b, M)
# plot2D_samples_mat is called without an Axes, so it draws on pyplot's
# current Axes; make the right panel current explicitly instead of relying
# on it happening to be the last Axes created.
plt.sca(axs[1])
ot.plot.plot2D_samples_mat(samples2, samples1, G0, c=[.5, .5, 1])
axs[1].plot(samples2[:, 0], samples2[:, 1], '+r', markersize=10, label='Source samples')  # Increased marker size
axs[1].plot(samples1[:, 0], samples1[:, 1], 'xb', markersize=10, label='Target samples')  # Increased marker size
axs[1].legend(loc=4)

# Common labels and limits for both panels
for ax in axs:
    ax.set(xlabel='X', xlim=(-4, 6.5), ylim=(-4, 6.5))

# Remove y-axis from the second figure (it shares the range of the first)
axs[1].set(ylabel='')
axs[1].yaxis.set_visible(False)

# Adjust layout and show the plot
fig.tight_layout()
plt.show()
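You can't do this directly. `sns.kdeplot` only draws on the single Axes passed as `ax` (`kind` and `space` are `jointplot` options, not `kdeplot` ones). The marginal distributions require a `jointplot`, which is a figure-level plot and cannot directly add extra axes.
It is, however, fairly easy to modify the `JointGrid` code to add more axes.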
The key is to change:
# Fragment showing the two changes to JointGrid's layout code; `ratio` and `f`
# are the grid ratio and the figure (both defined in the full example below).
# 1) Add `ratio` extra grid columns to accommodate an extra plot.
#    Before: gs = plt.GridSpec(ratio + 1, ratio + 1)
gs = plt.GridSpec(ratio + 1, ratio + 1 + ratio)
# 2) Index columns from the left instead of relative to the right edge, since
#    `:-1` would now also cover the new columns (example shown for ax_joint).
#    Before: ax_joint = f.add_subplot(gs[1:, :-1])  # all columns but the last
ax_joint = f.add_subplot(gs[1:, :ratio])  # use only the first `ratio` columns
Which gives us:
# Hand-built JointGrid-style layout with room for an extra plot:
# - grid columns [0, ratio) hold the joint KDE (rows 1..ratio) and the
#   x-marginal (row 0);
# - column `ratio` holds the y-marginal;
# - columns (ratio, 2*ratio] hold the optimal-transport plot.
ratio = 5  # joint plot spans `ratio` rows/columns, each marginal strip one
space = .2  # spacing passed to subplots_adjust below
f = plt.figure(figsize=(12, 4))
# Create the grid through the figure so it is explicitly tied to `f`.
gs = f.add_gridspec(ratio + 1, ratio + 1 + ratio)
ax_joint = f.add_subplot(gs[1:, :ratio])
ax_marg_x = f.add_subplot(gs[0, :ratio], sharex=ax_joint)
ax_marg_y = f.add_subplot(gs[1:, ratio], sharey=ax_joint)
ax_ot = f.add_subplot(gs[1:, ratio + 1:], sharey=ax_joint)

# Turn off tick visibility for the measure axis on the marginal plots
plt.setp(ax_marg_x.get_xticklabels(), visible=False)
plt.setp(ax_marg_y.get_yticklabels(), visible=False)
plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)
plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)
# Turn off ticks and tick labels on the density axis of the marginal plots
plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
plt.setp(ax_marg_x.get_yticklabels(), visible=False)
plt.setp(ax_marg_y.get_xticklabels(), visible=False)
plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)
plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)
ax_marg_x.yaxis.grid(False)
ax_marg_y.xaxis.grid(False)

# Remove the density-axis spines and the axis labels of the marginals
utils = sns.axisgrid.utils
utils.despine(ax=ax_marg_x, left=True)
utils.despine(ax=ax_marg_y, bottom=True)
for axes in [ax_marg_x, ax_marg_y]:
    for axis in [axes.xaxis, axes.yaxis]:
        axis.label.set_visible(False)
f.tight_layout()
f.subplots_adjust(hspace=space, wspace=space)

# Joint KDE and the two marginal KDEs
sns.kdeplot(data=df1, x='X', y='Y', hue='Distribution', fill=True, palette=custom_palette, ax=ax_joint)
sns.move_legend(ax_joint, 'lower right')
sns.kdeplot(data=df1, x='X', hue='Distribution', fill=True, palette=custom_palette, legend=False,
            ax=ax_marg_x)
sns.kdeplot(data=df1, y='Y', hue='Distribution', fill=True, palette=custom_palette, legend=False,
            ax=ax_marg_y)

# Optimal Transport matching between the samples
a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform weights on the samples
M = ot.dist(samples2, samples1, metric='euclidean')
G0 = ot.emd(a, b, M)
# plot2D_samples_mat is called without an Axes, so it draws on pyplot's
# current Axes; select ax_ot explicitly rather than relying on it being the
# last Axes created.
plt.sca(ax_ot)
ot.plot.plot2D_samples_mat(samples2, samples1, G0, c=[.5, .5, 1])
ax_ot.plot(samples2[:, 0], samples2[:, 1], '+r', markersize=10, label='Source samples')  # Increased marker size
ax_ot.plot(samples1[:, 0], samples1[:, 1], 'xb', markersize=10, label='Target samples')  # Increased marker size
ax_ot.legend(loc=4)
Output:
`JointGrid` for full flexibility
Another approach is to create a subclass of `JointGrid` that can accept an existing Figure/GridSpec/Axes as input and use those instead of creating its own.
In the example below, the `JointGridCustom` class expects either `custom_gs=None` (the default) or `custom_gs=(f, gs, ax_joint, ax_marg_x, ax_marg_y)` to reuse existing objects. This allows customization while letting seaborn handle the joint plot:
# Preview of the usage; JointGridCustom itself is defined in the full code below.
# 8x8 grid layout:
#   row 0,    cols 0-2 -> x marginal    | rows 1-5, cols 0-2 -> joint plot
#   rows 1-5, col 3    -> y marginal    | col 4 left empty as a gap
#   rows 1-5, cols 5-7 -> extra axes    | row 6 left empty, row 7 -> full-width axes
f = plt.figure(figsize=(10, 5))
gs = plt.GridSpec(8, 8)
ax_joint = f.add_subplot(gs[1:6, :3])
ax_marg_x = f.add_subplot(gs[0, :3], sharex=ax_joint)
ax_marg_y = f.add_subplot(gs[1:6, 3], sharey=ax_joint)
# ax_ot and ax_bottom are not passed to JointGridCustom; nothing is drawn on
# them in this snippet.
ax_ot = f.add_subplot(gs[1:6, 5:], sharey=ax_joint)
ax_bottom = f.add_subplot(gs[7:, :], sharey=ax_joint)
# Hand the pre-built figure/grid/axes to the grid instead of letting it create its own
g = JointGridCustom(data=df1, x='X', y='Y', hue='Distribution', space=0, palette=custom_palette,
custom_gs=(f, gs, ax_joint, ax_marg_x, ax_marg_y)
)
# First function draws the joint plot, second the marginals
g.plot(sns.kdeplot, sns.kdeplot, fill=True)
Example output:
Full code:
import matplotlib
from inspect import signature
from seaborn._base import VectorPlotter, variable_type, categorical_order
from seaborn._core.data import handle_data_source
from seaborn._compat import share_axis, get_legend_handles
from seaborn import utils
from seaborn.utils import (
adjust_legend_subtitles,
set_hls_values,
_check_argument,
_draw_figure,
_disable_autolayout
)
from seaborn.palettes import color_palette, blend_palette
class JointGridCustom(sns.JointGrid):
    """Grid for drawing a bivariate plot with marginal univariate plots.

    Many plots can be drawn by using the figure-level interface :func:`jointplot`.
    Use this class directly when you need more flexibility.

    Unlike :class:`seaborn.JointGrid`, this subclass can draw into a figure,
    grid spec and axes the caller has already created, so the joint plot can
    share a figure with other, unrelated axes.

    Parameters
    ----------
    data, x, y, hue, height, ratio, space, palette, hue_order, hue_norm,
    dropna, xlim, ylim, marginal_ticks
        As for :class:`seaborn.JointGrid`. ``height`` and ``ratio`` are only
        used when ``custom_gs`` is ``None``.
    custom_gs : tuple or None
        ``None`` (default) creates a new ``height`` x ``height`` figure with a
        ``(ratio + 1) x (ratio + 1)`` grid. Otherwise a 5-tuple
        ``(figure, gridspec, ax_joint, ax_marg_x, ax_marg_y)`` of existing
        matplotlib objects to draw into.

    Raises
    ------
    ValueError
        If ``custom_gs`` does not contain exactly five items.
    TypeError
        If an item of ``custom_gs`` is not of the expected matplotlib type.

    Notes
    -----
    The parent ``__init__`` is deliberately not called, because it would create
    its own figure. The figure-level calls at the end of ``__init__``
    (``utils.despine(f)``, ``f.tight_layout()`` and ``f.subplots_adjust``) act
    on the whole figure ``f``. With ``custom_gs`` they may therefore also
    affect other axes the caller placed on that figure.
    """

    def __init__(
        self, data=None, *,
        x=None, y=None, hue=None,
        height=6, ratio=5, space=.2,
        palette=None, hue_order=None, hue_norm=None,
        dropna=False, xlim=None, ylim=None, marginal_ticks=False,
        custom_gs=None,
    ):
        # Set up the subplot grid
        if custom_gs is not None:
            if len(custom_gs) != 5:
                raise ValueError(
                    "custom_gs must be (figure, gridspec, ax_joint, ax_marg_x, "
                    f"ax_marg_y); got {len(custom_gs)} items"
                )
            f, gs, ax_joint, ax_marg_x, ax_marg_y = custom_gs
            # Validate with explicit exceptions rather than `assert`, which is
            # stripped when Python runs with -O.
            expected_types = (
                ("figure", f, matplotlib.figure.Figure),
                ("gridspec", gs, matplotlib.gridspec.GridSpec),
                ("ax_joint", ax_joint, matplotlib.axes.Axes),
                ("ax_marg_x", ax_marg_x, matplotlib.axes.Axes),
                ("ax_marg_y", ax_marg_y, matplotlib.axes.Axes),
            )
            for label, obj, cls in expected_types:
                if not isinstance(obj, cls):
                    raise TypeError(
                        f"custom_gs {label} must be a {cls.__name__}, "
                        f"got {type(obj).__name__}"
                    )
        else:
            # Default JointGrid layout: joint plot in the lower-left block,
            # marginals along the top row and the right-hand column.
            f = plt.figure(figsize=(height, height))
            gs = plt.GridSpec(ratio + 1, ratio + 1)
            ax_joint = f.add_subplot(gs[1:, :-1])
            ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
            ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)

        self._figure = f
        self.ax_joint = ax_joint
        self.ax_marg_x = ax_marg_x
        self.ax_marg_y = ax_marg_y

        # Turn off tick visibility for the measure axis on the marginal plots
        plt.setp(ax_marg_x.get_xticklabels(), visible=False)
        plt.setp(ax_marg_y.get_yticklabels(), visible=False)
        plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)
        plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)

        # Turn off the ticks on the density axis for the marginal plots
        if not marginal_ticks:
            plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
            plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
            plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
            plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
            plt.setp(ax_marg_x.get_yticklabels(), visible=False)
            plt.setp(ax_marg_y.get_xticklabels(), visible=False)
            plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)
            plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)
            ax_marg_x.yaxis.grid(False)
            ax_marg_y.xaxis.grid(False)

        # Process the input variables
        p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue))
        # Keep only columns that contain at least one non-missing value
        plot_data = p.plot_data.loc[:, p.plot_data.notna().any()]

        # Possibly drop NA
        if dropna:
            plot_data = plot_data.dropna()

        def get_var(var):
            # Return the column for `var` renamed to the user-facing variable
            # name, or None when that variable was not supplied.
            vector = plot_data.get(var, None)
            if vector is not None:
                vector = vector.rename(p.variables.get(var, None))
            return vector

        self.x = get_var("x")
        self.y = get_var("y")
        self.hue = get_var("hue")

        # Label the joint axes with the variable names, when known
        for axis in "xy":
            name = p.variables.get(axis, None)
            if name is not None:
                getattr(ax_joint, f"set_{axis}label")(name)

        if xlim is not None:
            ax_joint.set_xlim(xlim)
        if ylim is not None:
            ax_joint.set_ylim(ylim)

        # Store the semantic mapping parameters for axes-level functions
        self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm)

        # Make the grid look nice. These calls operate on the whole figure `f`,
        # including any extra axes a caller added via custom_gs.
        utils.despine(f)
        if not marginal_ticks:
            utils.despine(ax=ax_marg_x, left=True)
            utils.despine(ax=ax_marg_y, bottom=True)
        for axes in [ax_marg_x, ax_marg_y]:
            for axis in [axes.xaxis, axes.yaxis]:
                axis.label.set_visible(False)
        f.tight_layout()
        f.subplots_adjust(hspace=space, wspace=space)
# Build the whole figure layout up front on an 8x8 grid:
#   row 0,    cols 0-2 -> x marginal    | rows 1-5, cols 0-2 -> joint plot
#   rows 1-5, col 3    -> y marginal    | col 4 left empty as a gap
#   rows 1-5, cols 5-7 -> OT plot       | row 6 left empty, row 7 -> full-width axes
f = plt.figure(figsize=(10, 5))
gs = plt.GridSpec(8, 8)
ax_joint = f.add_subplot(gs[1:6, :3])
ax_marg_x = f.add_subplot(gs[0, :3], sharex=ax_joint)
ax_marg_y = f.add_subplot(gs[1:6, 3], sharey=ax_joint)
ax_ot = f.add_subplot(gs[1:6, 5:], sharey=ax_joint)
ax_bottom = f.add_subplot(gs[7:, :], sharey=ax_joint)  # spare axes, left empty here

# Let seaborn draw the joint plot and marginals into the pre-built axes
g = JointGridCustom(data=df1, x='X', y='Y', hue='Distribution', space=0, palette=custom_palette,
                    custom_gs=(f, gs, ax_joint, ax_marg_x, ax_marg_y))
g.plot(sns.kdeplot, sns.kdeplot, fill=True)

# Fill the extra panel with the optimal transport matching, so the joint plot
# and the OT plot appear side by side as asked in the question.
a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform weights on the samples
M = ot.dist(samples2, samples1, metric='euclidean')
G0 = ot.emd(a, b, M)
# plot2D_samples_mat is called without an Axes, so it draws on pyplot's
# current Axes; select ax_ot explicitly.
plt.sca(ax_ot)
ot.plot.plot2D_samples_mat(samples2, samples1, G0, c=[.5, .5, 1])
ax_ot.plot(samples2[:, 0], samples2[:, 1], '+r', markersize=10, label='Source samples')
ax_ot.plot(samples1[:, 0], samples1[:, 1], 'xb', markersize=10, label='Target samples')
ax_ot.legend(loc=4)