python matplotlib twinx

Shared secondary axes


How can I set up a shared secondary axis across subplots in matplotlib?

Here is the minimal code to display the issue:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def countour_every(ax, every, x_data, y_data,
                   color='black', linestyle='-', marker='o', **kwargs):
    """Draw a line on *ax* with a marker on every ``every``-th point.

    Parameters
    ----------
    ax : object exposing a Matplotlib-style ``plot`` method.
    every : marker spacing, handed to ``plot`` as ``markevery``.
    x_data, y_data : coordinates of the line.
    color, linestyle, marker : appearance of the line and its markers.
        ``linestyle`` is passed positionally as the format string.
    **kwargs : accepted and deliberately ignored, so callers can splat a
        whole data dict that also carries non-plot keys (e.g. 'y_lim').

    Returns
    -------
    The single line object produced by ``ax.plot``.
    """
    # Previously only the line style reached plot(); colour, marker and the
    # marker spacing were silently dropped.
    line, = ax.plot(x_data, y_data, linestyle,
                    color=color, marker=marker, markevery=every)
    return line


def prettify_axes(ax, data):
    """Apply title, limits, legend and small fonts to a primary axes.

    Parameters
    ----------
    ax : the axes to style.
    data : mapping with the optional keys 'title', 'y_lim' and 'x_lim';
        each key that is present is applied to *ax*.
    """
    if 'title' in data:
        ax.set_title(data['title'])

    if 'y_lim' in data:
        ax.set_ylim(data['y_lim'])

    if 'x_lim' in data:
        ax.set_xlim(data['x_lim'])

    # Draw the legend only when at least one artist carries a label;
    # otherwise there is nothing to put in it.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper right', prop={'size': 6})

    ax.title.set_fontsize(7)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(labelsize=6, direction='in')
        axis.label.set_size(7)


def prettify_second_axes(ax):
    """Style the y axis of *ax*: size-7 red tick labels, size-7 axis label."""
    y_axis = ax.yaxis
    y_axis.set_tick_params(labelsize=7, labelcolor='red')
    y_axis.label.set_size(7)


def compare_plot(ax, data):
    """Plot two series on *ax* and their difference on a twin y axis.

    Parameters
    ----------
    ax : the primary axes to draw the two series on.
    data : sequence whose first two items are dicts. Each dict is splatted
        into ``countour_every`` and must therefore hold 'x_data' and
        'y_data'; an optional 'label' becomes the line's legend label.
        The 'y_data' values must support subtraction (e.g. NumPy arrays).
        ``data[0]`` is also handed to ``prettify_axes``.

    Returns
    -------
    The twin axes created by ``ax.twinx()``, which carries the difference
    curve, so callers can link or restyle the secondary axes.
    """
    first, second = data[0], data[1]

    for series in (first, second):
        line = countour_every(ax, 10, **series)
        if 'label' in series:
            line.set_label(series['label'])

    ax2 = ax.twinx()
    # The difference curve must live on the twin axes: drawn on `ax` it is
    # clipped by the primary y-limits and leaves the secondary axis without
    # any data to autoscale to.
    ax2.plot(
            first['x_data'],
            first['y_data'] - second['y_data'], '-',
            color='red', alpha=.2, zorder=1)

    prettify_axes(ax, first)
    prettify_second_axes(ax2)
    return ax2


# Four demo series over positions 0..9. Each dict is splatted into
# countour_every() by compare_plot(); 'label' and 'y_lim' are optional extras
# read by compare_plot() and prettify_axes().
d0 = {
    'x_data': np.arange(0, 10),
    'y_data': abs(np.random.random(10)),
    'y_lim': [-1, 1],
    'color': '.7',
    'linestyle': '-',
    'label': 'd0',
}
d1 = {
    'x_data': np.arange(0, 10),
    'y_data': -abs(np.random.random(10)),
    'y_lim': [-1, 1],
    'color': '.7',
    'linestyle': '--',
    'label': 'd1',
}
d2 = {
    'x_data': np.arange(0, 10),
    'y_data': np.random.random(10),
    'y_lim': [-1, 1],
    'color': '.7',
    'linestyle': '-.',
}
d3 = {
    'x_data': np.arange(0, 10),
    'y_data': -np.ones(10),
    'y_lim': [-1, 1],
    'color': '.7',
    'linestyle': '-.',
}

# 2x2 grid of panels whose primary x and y axes are shared.
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 6)

# One comparison per panel, filled row by row.
panel_pairs = [(d0, d1), (d0, d2), (d1, d0), (d3, d2)]
for panel, (first, second) in zip(axes.flat, panel_pairs):
    compare_plot(panel, [first, second])

# Figure-level title and axis captions; the layout rect leaves room for them.
fig.suptitle('A comparison chart')
fig.set_tight_layout({'rect': [0, 0.03, 1, 0.95]})
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')

fig.savefig('demo.png', dpi=300)

That generates the following image

Shared axes issue

We can see that the X axis and the Y axis are correctly shared, but the secondary (twin) axis is repeated in every subplot.

Also, the secondary axis isn't scaling correctly to fit the data (that should happen independently of the limits set on the primary y axis).


Solution

  • You will need to share the twin axes manually and also remove the ticklabels

    def compare_plot(ax, data):
        # ...
        ax2 = ax.twinx()
        # ...
        return ax2

    sax1 = compare_plot(axes[0][0], [d0, d1])
    sax2 = compare_plot(axes[0][1], [d0, d2])
    sax3 = compare_plot(axes[1][0], [d1, d0])
    sax4 = compare_plot(axes[1][1], [d3, d2])

    # Link every twin axis to the first one. Axes.sharey() (Matplotlib 3.3+)
    # replaces get_shared_y_axes().join(), which was deprecated in
    # Matplotlib 3.6 and later removed.
    for sax in [sax2, sax3, sax4]:
        sax.sharey(sax1)
    sax1.autoscale()
    # Hide the right-hand tick labels on the left-column panels.
    for sax in [sax1, sax3]:
        sax.yaxis.set_tick_params(labelright=False)
    

    enter image description here