Tags: python, seaborn, kdeplot

Problem with seaborn kdeplot() when plotting two figures side by side


I am trying to plot two 2D distributions together with their marginal distributions on the top and side of the figure, like so: [image]

Now I want to combine the above figure with the following one, so that they appear side by side: [image]

However, when doing so, the marginal distributions aren't plotted. Can anyone help? [image]

The code for plotting the above figure is given here:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import multivariate_normal
import ot
import ot.plot

# Define the mean and covariance for two different multivariate normal distributions
mean1 = [0, 0]
cov1 = [[1, 0.5], [0.5, 1]]

mean2 = [3, 3]
cov2 = [[1, -0.5], [-0.5, 1]]

n = 100  # number of samples drawn from each distribution

# Generate random samples from the distributions (seeded for reproducibility)
np.random.seed(0)
samples1 = np.random.multivariate_normal(mean1, cov1, size=n)
samples2 = np.random.multivariate_normal(mean2, cov2, size=n)

# Long-form frame for seaborn: the first n rows come from samples1 ("Target"),
# the next n rows from samples2 ("Source").
# The label column is built in one assignment. The chained form
# ``df1['Distribution'].iloc[n:] = 'Source'`` triggers SettingWithCopyWarning
# and silently has no effect under pandas Copy-on-Write.
df1 = pd.DataFrame(np.concatenate([samples1, samples2]), columns=['X', 'Y'])
df1['Distribution'] = ['Target'] * n + ['Source'] * n

# Create a custom palette with blue and red
custom_palette = {'Target': 'blue', 'Source': 'red'}

# Plotting side by side
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Bivariate KDE of both distributions on the left axes.
# ``kind`` and ``space`` are jointplot/JointGrid options, not kdeplot options.
# Passed to kdeplot, they are only forwarded to matplotlib's contour call, so
# they are not used here. An axes-level kdeplot never draws marginal plots.
sns.kdeplot(data=df1, x="X", y="Y", hue="Distribution", fill=True,
            palette=custom_palette, ax=axs[0])
# Optionally: axs[0].set_aspect('equal', adjustable='box')
sns.move_legend(axs[0], "lower right")

# Optimal Transport matching between the samples
a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples
M = ot.dist(samples2, samples1, metric='euclidean')
G0 = ot.emd(a, b, M)
# plot2D_samples_mat takes no ax argument and draws through pyplot onto the
# current axes. Select the right-hand axes explicitly instead of relying on
# it being the last axes created.
plt.sca(axs[1])
ot.plot.plot2D_samples_mat(samples2, samples1, G0, c=[.5, .5, 1])
axs[1].plot(samples2[:, 0], samples2[:, 1], '+r', markersize=10, label='Source samples')  # Increased marker size
axs[1].plot(samples1[:, 0], samples1[:, 1], 'xb', markersize=10, label='Target samples')  # Increased marker size
axs[1].legend(loc=4)  # loc=4 is 'lower right', matching the KDE panel

# Common labels and limits (applied once to both panels)
for ax in axs:
    ax.set(xlabel='X')
    ax.set_xlim([-4, 6.5])
    ax.set_ylim([-4, 6.5])

# Remove y-axis from the second figure
axs[1].set(ylabel='')
axs[1].yaxis.set_visible(False)

# Adjust layout
fig.tight_layout()

# Show plot
plt.show()

Solution

  • You can't do this directly. The marginal distributions require a jointplot, which is a figure-level function: it creates its own figure, so it can neither be drawn onto an existing Axes nor share its figure with extra axes.

    However, it's fairly easy to modify the JointGrid code to add more axes.

    The key is to change:

    # Add `ratio` extra grid columns to make room for an extra plot.
    # Original JointGrid layout: gs = plt.GridSpec(ratio + 1, ratio + 1)
    gs = plt.GridSpec(ratio + 1, ratio + 1 + ratio)
    
    # Place each axes by explicit column indices (shown here for ax_joint).
    # With the widened grid, the original `:-1` would also span the new columns.
    # Original: ax_joint = f.add_subplot(gs[1:, :-1])   # all columns but the last
    ax_joint = f.add_subplot(gs[1:, :ratio])  # first `ratio` columns only
    

    Which gives us:

    ratio = 5   # the joint panel spans `ratio` grid rows/columns; each marginal spans one
    space = .2  # gap between panels, applied via subplots_adjust below
    
    # Widened JointGrid layout: ratio + 1 rows and ratio + 1 + ratio columns.
    # Columns [0, ratio) hold the joint plot, column `ratio` holds the y-marginal,
    # and the remaining `ratio` columns hold the optimal-transport (OT) panel.
    f = plt.figure(figsize=(12, 4))
    gs = plt.GridSpec(ratio + 1, ratio + 1 + ratio)
    ax_joint = f.add_subplot(gs[1:, :ratio])
    ax_marg_x = f.add_subplot(gs[0, :ratio], sharex=ax_joint)
    ax_marg_y = f.add_subplot(gs[1:, ratio], sharey=ax_joint)
    ax_ot = f.add_subplot(gs[1:, ratio+1:], sharey=ax_joint)
    
    # Same marginal-axes cosmetics as JointGrid. On each marginal:
    #   - hide the tick labels of the axis shared with the joint plot;
    #   - hide ticks, tick labels and grid lines on the density axis.
    for shared_axis, density_axis in [(ax_marg_x.xaxis, ax_marg_x.yaxis),
                                      (ax_marg_y.yaxis, ax_marg_y.xaxis)]:
        plt.setp(shared_axis.get_ticklabels(), visible=False)
        plt.setp(shared_axis.get_ticklabels(minor=True), visible=False)
        plt.setp(density_axis.get_majorticklines(), visible=False)
        plt.setp(density_axis.get_minorticklines(), visible=False)
        plt.setp(density_axis.get_ticklabels(), visible=False)
        plt.setp(density_axis.get_ticklabels(minor=True), visible=False)
        density_axis.grid(False)
    
    # Use the public sns.despine rather than reaching into sns.axisgrid.utils.
    sns.despine(ax=ax_marg_x, left=True)
    sns.despine(ax=ax_marg_y, bottom=True)
    for axes in [ax_marg_x, ax_marg_y]:
        for axis in [axes.xaxis, axes.yaxis]:
            axis.label.set_visible(False)
    f.tight_layout()
    f.subplots_adjust(hspace=space, wspace=space)
    
    # Bivariate KDE on the joint axes; univariate KDEs on the two marginals.
    sns.kdeplot(data=df1, x='X', y='Y', hue='Distribution', fill=True, palette=custom_palette, ax=ax_joint)
    sns.move_legend(ax_joint, 'lower right')
    sns.kdeplot(data=df1, x='X', hue='Distribution', fill=True, palette=custom_palette, legend=False,
                ax=ax_marg_x)
    sns.kdeplot(data=df1, y='Y', hue='Distribution', fill=True, palette=custom_palette, legend=False,
                ax=ax_marg_y)
    
    # Optimal Transport matching between the samples
    a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples
    M = ot.dist(samples2, samples1, metric='euclidean')
    G0 = ot.emd(a, b, M)
    # plot2D_samples_mat takes no ax argument and draws through pyplot onto the
    # current axes. Select ax_ot explicitly instead of relying on it being the
    # last axes created.
    plt.sca(ax_ot)
    ot.plot.plot2D_samples_mat(samples2, samples1, G0, c=[.5, .5, 1])
    ax_ot.plot(samples2[:, 0], samples2[:, 1], '+r', markersize=10, label='Source samples')  # Increased marker size
    ax_ot.plot(samples1[:, 0], samples1[:, 1], 'xb', markersize=10, label='Target samples')  # Increased marker size
    ax_ot.legend(loc=4)  # loc=4 is 'lower right'
    

    Output:

    seaborn jointplot with extra ax

    Modifying JointGrid for full flexibility

    Another approach is to create a subclass of JointGrid that can accept an existing Figure/GridSpec/Axes as input and use those instead of creating its own.

    In the example below, the JointGridCustom class accepts either custom_gs=None (the default, which creates a new figure just like JointGrid) or custom_gs=(f, gs, ax_joint, ax_marg_x, ax_marg_y) to reuse existing objects. This allows full customization of the layout while still letting seaborn draw the jointplot:

    # 8x8 grid layout:
    #   - row 0, columns 0-2: x-marginal;
    #   - rows 1-5, columns 0-2: joint plot; column 3: y-marginal;
    #   - rows 1-5, columns 5-7: an extra panel;
    #   - row 7: a full-width strip.
    # Row 6 and column 4 are left empty and act as spacers.
    f = plt.figure(figsize=(10, 5))
    gs = plt.GridSpec(8, 8)
    ax_joint = f.add_subplot(gs[1:6, :3])
    ax_marg_x = f.add_subplot(gs[0, :3], sharex=ax_joint)
    ax_marg_y = f.add_subplot(gs[1:6, 3], sharey=ax_joint)
    # Extra axes not handed to JointGridCustom; nothing is drawn on them in this snippet.
    ax_ot = f.add_subplot(gs[1:6, 5:], sharey=ax_joint)
    # NOTE(review): sharey ties this full-width strip's y-range to ax_joint. Confirm that is intended.
    ax_bottom = f.add_subplot(gs[7:, :], sharey=ax_joint)
    
    # Reuse the pre-built figure/gridspec/axes instead of letting JointGrid create a figure.
    g = JointGridCustom(data=df1, x='X', y='Y', hue='Distribution', space=0, palette=custom_palette,
                        custom_gs=(f, gs, ax_joint, ax_marg_x, ax_marg_y)
                       )
    # kdeplot is used for both the joint axes and the two marginal axes.
    g.plot(sns.kdeplot, sns.kdeplot, fill=True)
    

    Example output:

    custom JointGrid class to reuse existing figure/gridspec/axes

    Full code:

    import matplotlib
    from inspect import signature
    
    from seaborn._base import VectorPlotter, variable_type, categorical_order
    from seaborn._core.data import handle_data_source
    from seaborn._compat import share_axis, get_legend_handles
    from seaborn import utils
    from seaborn.utils import (
        adjust_legend_subtitles,
        set_hls_values,
        _check_argument,
        _draw_figure,
        _disable_autolayout
    )
    from seaborn.palettes import color_palette, blend_palette
    
    class JointGridCustom(sns.JointGrid):
        """Grid for drawing a bivariate plot with marginal univariate plots.
    
        Many plots can be drawn by using the figure-level interface :func:`jointplot`.
        Use this class directly when you need more flexibility.
    
        This variant can draw into an existing figure instead of creating one.
        It accepts every JointGrid parameter, plus:
    
        custom_gs : tuple (f, gs, ax_joint, ax_marg_x, ax_marg_y), optional
            Existing matplotlib Figure, GridSpec (or any GridSpecBase), and the
            three Axes to use for the joint plot and the x/y marginals.
            When None (default), a new ``height`` x ``height`` figure with a
            ``(ratio + 1)`` x ``(ratio + 1)`` grid is created.
    
        Raises TypeError / ValueError if ``custom_gs`` has the wrong shape or types.
    
        Note: ``despine(f)``, ``tight_layout`` and
        ``subplots_adjust(hspace=space, wspace=space)`` act on the whole figure
        ``f``, including any extra Axes not listed in ``custom_gs``.
        """
        
        def __init__(
            self, data=None, *,
            x=None, y=None, hue=None,
            height=6, ratio=5, space=.2,
            palette=None, hue_order=None, hue_norm=None,
            dropna=False, xlim=None, ylim=None, marginal_ticks=False,
            custom_gs=None,
        ):
            # This body replicates the parent's setup rather than calling
            # super().__init__, so that it can reuse the caller's figure/axes.
    
            # Set up the subplot grid
            if custom_gs is not None:
                try:
                    f, gs, ax_joint, ax_marg_x, ax_marg_y = custom_gs
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        "custom_gs must be a 5-tuple "
                        "(figure, gridspec, ax_joint, ax_marg_x, ax_marg_y)"
                    ) from err
                # Explicit checks instead of `assert`, which is stripped under `python -O`.
                expected_types = (
                    ("figure", f, matplotlib.figure.Figure),
                    ("gridspec", gs, matplotlib.gridspec.GridSpecBase),
                    ("ax_joint", ax_joint, matplotlib.axes.Axes),
                    ("ax_marg_x", ax_marg_x, matplotlib.axes.Axes),
                    ("ax_marg_y", ax_marg_y, matplotlib.axes.Axes),
                )
                for label, obj, cls in expected_types:
                    if not isinstance(obj, cls):
                        raise TypeError(
                            f"custom_gs {label} must be a {cls.__name__}, "
                            f"got {type(obj).__name__}"
                        )
            else:
                f = plt.figure(figsize=(height, height))
                gs = plt.GridSpec(ratio + 1, ratio + 1)
    
                ax_joint = f.add_subplot(gs[1:, :-1])
                ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
                ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)
    
            self._figure = f
            self.ax_joint = ax_joint
            self.ax_marg_x = ax_marg_x
            self.ax_marg_y = ax_marg_y
    
            # Turn off tick visibility for the measure axis on the marginal plots
            plt.setp(ax_marg_x.get_xticklabels(), visible=False)
            plt.setp(ax_marg_y.get_yticklabels(), visible=False)
            plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)
            plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)
    
            # Turn off the ticks on the density axis for the marginal plots
            if not marginal_ticks:
                plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
                plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
                plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
                plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
                plt.setp(ax_marg_x.get_yticklabels(), visible=False)
                plt.setp(ax_marg_y.get_xticklabels(), visible=False)
                plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)
                plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)
                ax_marg_x.yaxis.grid(False)
                ax_marg_y.xaxis.grid(False)
    
            # Process the input variables
            p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue))
            # Drop columns that are entirely missing
            plot_data = p.plot_data.loc[:, p.plot_data.notna().any()]
    
            # Possibly drop NA
            if dropna:
                plot_data = plot_data.dropna()
    
            def get_var(var):
                # Return the column for `var` renamed to the user's variable name, or None.
                vector = plot_data.get(var, None)
                if vector is not None:
                    vector = vector.rename(p.variables.get(var, None))
                return vector
    
            self.x = get_var("x")
            self.y = get_var("y")
            self.hue = get_var("hue")
    
            for axis in "xy":
                name = p.variables.get(axis, None)
                if name is not None:
                    getattr(ax_joint, f"set_{axis}label")(name)
    
            if xlim is not None:
                ax_joint.set_xlim(xlim)
            if ylim is not None:
                ax_joint.set_ylim(ylim)
    
            # Store the semantic mapping parameters for axes-level functions
            self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm)
    
            # Make the grid look nice (note: these calls act on the whole figure f)
            utils.despine(f)
            if not marginal_ticks:
                utils.despine(ax=ax_marg_x, left=True)
                utils.despine(ax=ax_marg_y, bottom=True)
            for axes in [ax_marg_x, ax_marg_y]:
                for axis in [axes.xaxis, axes.yaxis]:
                    axis.label.set_visible(False)
            f.tight_layout()
            f.subplots_adjust(hspace=space, wspace=space)
    
    
        
    # Layout on an 8x8 grid:
    #   - row 0, columns 0-2: x-marginal;
    #   - rows 1-5, columns 0-2: joint plot; column 3: y-marginal;
    #   - rows 1-5, columns 5-7: extra panel;
    #   - row 7: full-width strip.
    # Row 6 and column 4 stay empty as spacers.
    f = plt.figure(figsize=(10, 5))
    gs = plt.GridSpec(8, 8)
    joint_rows = slice(1, 6)
    ax_joint = f.add_subplot(gs[joint_rows, :3])
    ax_marg_x = f.add_subplot(gs[0, :3], sharex=ax_joint)
    ax_marg_y = f.add_subplot(gs[joint_rows, 3], sharey=ax_joint)
    ax_ot = f.add_subplot(gs[joint_rows, 5:], sharey=ax_joint)
    ax_bottom = f.add_subplot(gs[7:, :], sharey=ax_joint)
    
    # Hand the pre-built figure, grid and joint/marginal axes to the custom grid.
    custom_gs = (f, gs, ax_joint, ax_marg_x, ax_marg_y)
    g = JointGridCustom(
        data=df1,
        x='X',
        y='Y',
        hue='Distribution',
        space=0,
        palette=custom_palette,
        custom_gs=custom_gs,
    )
    # Same KDE function for the joint axes and both marginals.
    g.plot(sns.kdeplot, sns.kdeplot, fill=True)