How to share secondary (twin) y axes across subplots in Matplotlib?
Here is minimal code that reproduces the issue:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
def countour_every(ax, every, x_data, y_data,
                   color='black', linestyle='-', marker='o', **kwargs):
    """Draw a line on *ax* with a contour marker at every *every*-th point.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    every : int
        Marker spacing, forwarded to ``ax.plot`` as ``markevery``.
    x_data, y_data : array-like
        Coordinates of the line.
    color : str
        Colour of the line and its markers.
    linestyle : str
        Matplotlib line style, e.g. ``'-'``, ``'--'`` or ``'-.'``.
    marker : str
        Marker symbol drawn every *every* points.
    **kwargs
        Remaining keys of a splatted data dict (e.g. ``y_lim``, ``title``,
        ``label``). They are accepted but deliberately not forwarded:
        keys such as ``y_lim`` are not line properties and ``ax.plot``
        would reject them.

    Returns
    -------
    matplotlib.lines.Line2D
        The single line that was drawn.
    """
    # ax.plot returns a list of lines; exactly one is created by this call.
    line, = ax.plot(x_data, y_data, linestyle=linestyle, color=color,
                    marker=marker, markevery=every)
    return line
def prettify_axes(ax, data):
    """Apply title, limits, an optional legend and small-font styling to *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        The primary axes to style.
    data : dict
        The optional keys ``'title'``, ``'y_lim'`` and ``'x_lim'`` are applied
        when present; any other keys are ignored.
    """
    if 'title' in data:
        ax.set_title(data['title'])
    if 'y_lim' in data:
        ax.set_ylim(data['y_lim'])
    if 'x_lim' in data:
        ax.set_xlim(data['x_lim'])
    # Draw a legend only when some artist carries a public label (labels that
    # start with '_' are excluded). Calling legend() with nothing to show makes
    # Matplotlib warn "No artists with labels found to put in legend".
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper right', prop={'size': 6})
    ax.title.set_fontsize(7)
    # The x and y axes get identical styling: small inward ticks, 7 pt labels.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(labelsize=6, direction='in')
        axis.label.set_size(7)
def prettify_second_axes(ax):
    """Style the y axis of *ax*, which is intended to be a twin (secondary) axes.

    Tick labels become 7 pt and red; the axis label is set to 7 pt.
    """
    secondary = ax.yaxis
    secondary.set_tick_params(labelsize=7, labelcolor='red')
    secondary.label.set_size(7)
def compare_plot(ax, data):
    """Plot two series on *ax* and their difference on a twin y axis.

    Parameters
    ----------
    ax : matplotlib Axes
        Primary axes that receives the two series.
    data : sequence of two dicts
        Keyword dicts splatted into ``countour_every``. ``data[0]`` also
        supplies the title/limits passed to ``prettify_axes``. An optional
        ``'label'`` key on either dict becomes that line's legend label.

    Returns
    -------
    matplotlib Axes
        The twin axes holding the difference curve. Callers can use it to
        share or restyle the secondary axes across subplots; callers that
        ignore the return value are unaffected.
    """
    first, second = data[0], data[1]
    for series in (first, second):
        line = countour_every(ax, 10, **series)
        if 'label' in series:
            line.set_label(series['label'])
    ax2 = ax.twinx()
    # Plot the difference on ax2, not ax, so the secondary axis autoscales to
    # the error data independently of the y_lim applied to the primary axes.
    ax2.plot(first['x_data'], first['y_data'] - second['y_data'], '-',
             color='red', alpha=.2, zorder=1)
    prettify_axes(ax, first)
    prettify_second_axes(ax2)
    return ax2
# Demo series: d0 and d1 carry a 'label' key, d2 and d3 do not.
d0 = {'x_data': np.arange(0, 10), 'y_data': abs(np.random.random(10)),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-', 'label': 'd0'}
d1 = {'x_data': np.arange(0, 10), 'y_data': -abs(np.random.random(10)),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '--', 'label': 'd1'}
d2 = {'x_data': np.arange(0, 10), 'y_data': np.random.random(10),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}
d3 = {'x_data': np.arange(0, 10), 'y_data': -np.ones(10),
      'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}

fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 6)
compare_plot(axes[0][0], [d0, d1])
compare_plot(axes[0][1], [d0, d2])
compare_plot(axes[1][0], [d1, d0])
compare_plot(axes[1][1], [d3, d2])
fig.suptitle('A comparison chart')
# Shared axis titles, placed in figure coordinates.
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')
# Figure.set_tight_layout is discouraged since Matplotlib 3.6 in favour of
# set_layout_engine. A one-shot tight_layout right before saving applies the
# same rect (leaving room for the bottom text and the suptitle) and also
# works on older Matplotlib versions.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.savefig('demo.png', dpi=300)
That generates the following image:
We can see that the x and y axes are correctly shared, but the secondary (twin) y axis is repeated in every subplot.
Also, the secondary axis isn't scaled to fit its data (this should happen independently of the limits set on the primary y axis).
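You need to share the twin axes manually and also remove the redundant tick labels. The error curve must also be drawn on the twin axes (`ax2.plot(...)` rather than `ax.plot(...)`); otherwise the secondary axis has no data to autoscale to: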
def compare_plot(ax, data):
    """Sketch of the revised compare_plot: it now returns the twin axes.

    The ``# ...`` lines stand for the parts of the compare_plot body from the
    question that are not repeated here. The point of the sketch is that the
    twin axes created by ``ax.twinx()`` is returned, so the caller can share
    the secondary y axes between subplots and hide redundant tick labels.
    """
    # ...
    ax2 = ax.twinx()
    # ...
    return ax2
# Keep the twin (secondary) axes that compare_plot returns for each subplot.
sax1 = compare_plot(axes[0][0], [d0, d1])
sax2 = compare_plot(axes[0][1], [d0, d2])
sax3 = compare_plot(axes[1][0], [d1, d0])
sax4 = compare_plot(axes[1][1], [d3, d2])
# Share every secondary y axis with sax1. Axes.sharey exists since
# Matplotlib 3.3. Mutating get_shared_y_axes() via .join() is deprecated since
# 3.6 and does not work on recent versions, so it is only used as a fallback
# for old Matplotlib.
for sax in (sax2, sax3, sax4):
    if hasattr(sax, 'sharey'):
        sax.sharey(sax1)
    else:
        sax1.get_shared_y_axes().join(sax1, sax)
# Rescale the shared secondary axis to fit the error data of all subplots.
sax1.autoscale()
# The left-column twins draw their tick labels on the right, between the
# subplots; hide them so only the right column shows the secondary scale.
for sax in (sax1, sax3):
    sax.yaxis.set_tick_params(labelright=False)