Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Shared secondary axes

How can I set up a shared secondary axis across subplots in matplotlib?

Here is the minimal code to display the issue:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def countour_every(ax, every, x_data, y_data,
                   color='black', linestyle='-', marker='o', **kwargs):
    """Draw a line on *ax* with a marker at every *every*-th data point.

    ax        -- object with a matplotlib-style ``plot`` method.
    every     -- marker spacing, forwarded as the line's ``markevery``.
    x_data, y_data -- coordinates of the line.
    color, linestyle, marker -- forwarded to ``ax.plot`` as keywords.
    **kwargs  -- accepted and ignored, so that callers can splat a whole
                 data dict (which may also hold keys such as 'y_lim' or
                 'label' that are not plot properties).

    Returns the single line object created by ``ax.plot``.
    """
    # Pass the style as keywords: a positional format string only accepts
    # the short style codes, and previously color/marker/every were
    # silently dropped.
    line, = ax.plot(
        x_data, y_data,
        linestyle=linestyle, color=color,
        marker=marker, markevery=every)
    return line


def prettify_axes(ax, data):
    """Apply optional title/limits from *data* to *ax* and shrink its fonts.

    ax   -- the axes to style.
    data -- mapping; the optional keys 'title', 'y_lim' and 'x_lim' are
            applied when present, every other key is ignored.

    A legend is drawn only when at least one artist on *ax* has a label.
    """
    if 'title' in data:
        ax.set_title(data['title'])

    if 'y_lim' in data:
        ax.set_ylim(data['y_lim'])

    if 'x_lim' in data:
        ax.set_xlim(data['x_lim'])

    # Draw the legend only if some artist carries a label; asking for a
    # legend on axes without labelled artists makes matplotlib complain
    # that there is nothing to put in it.
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper right', prop={'size': 6})

    ax.title.set_fontsize(7)
    # Both axes get the same small, inward-pointing ticks and label size.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(labelsize=6, direction='in')
        axis.label.set_size(7)


def prettify_second_axes(ax):
    """Style the y-axis of *ax*: size-7 red tick labels, size-7 axis label."""
    secondary_y = ax.yaxis
    secondary_y.set_tick_params(labelsize=7, labelcolor='red')
    secondary_y.label.set_size(7)

def compare_plot(ax, data):
    """Plot two series on *ax* and their difference on a twin y-axis.

    ax   -- primary axes; receives both series and the primary styling.
    data -- sequence whose first two items are dicts with 'x_data' and
            'y_data' (the 'y_data' values must support subtraction) and
            an optional 'label'. Each dict is also splatted into
            ``countour_every``; the first one is passed to
            ``prettify_axes``.

    Returns the twin axes created with ``ax.twinx()`` so that callers can
    link or restyle the secondary axes of several subplots.
    """
    for series in data[:2]:
        line = countour_every(ax, 10, **series)
        if 'label' in series:
            line.set_label(series['label'])

    ax2 = ax.twinx()
    # The difference curve belongs on the twin axes: drawing it on ``ax``
    # leaves ``ax2`` without data, so its scale never follows the error.
    ax2.plot(
        data[0]['x_data'],
        data[0]['y_data'] - data[1]['y_data'], '-',
        color='red', alpha=.2, zorder=1)

    prettify_axes(ax, data[0])
    prettify_second_axes(ax2)
    return ax2


# Four sample series; each dict is splatted into the plotting helpers.
d0 = dict(
    x_data=np.arange(0, 10),
    y_data=abs(np.random.random(10)),
    y_lim=[-1, 1],
    color='.7',
    linestyle='-',
    label='d0',
)
d1 = dict(
    x_data=np.arange(0, 10),
    y_data=-abs(np.random.random(10)),
    y_lim=[-1, 1],
    color='.7',
    linestyle='--',
    label='d1',
)
d2 = dict(
    x_data=np.arange(0, 10),
    y_data=np.random.random(10),
    y_lim=[-1, 1],
    color='.7',
    linestyle='-.',
)
d3 = dict(
    x_data=np.arange(0, 10),
    y_data=-np.ones(10),
    y_lim=[-1, 1],
    color='.7',
    linestyle='-.',
)

# 2x2 grid whose primary x and y axes are shared between all panels.
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 6)

# One comparison per panel, filled row by row.
panels = ([d0, d1], [d0, d2], [d1, d0], [d3, d2])
for panel_ax, pair in zip(axes.flat, panels):
    compare_plot(panel_ax, pair)

# Figure-level title and axis captions (placed in figure coordinates).
fig.suptitle('A comparison chart')
fig.set_tight_layout({'rect': [0, 0.03, 1, 0.95]})
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')

fig.savefig('demo.png', dpi=300)

That generates the following image

Shared axes issue

We can see that the X axis and the Y axis are correctly shared, but the secondary (twin) axis is repeated in every subplot.

Also, the secondary axis isn't scaling correctly to fit the data (that should occur independently of the primary y axis being limited).

like image 412
Lin Avatar asked Mar 06 '26 13:03

Lin


1 Answer

You will need to share the twin axes manually and also remove the ticklabels

def compare_plot(ax, data):
    """Sketch of the question's compare_plot, changed to return the twin axes."""
    # ... (elided: draw the two data series on ``ax`` as in the question)
    ax2 = ax.twinx()
    # ... (elided: the rest of the question's body, i.e. the error curve
    # and the styling of both axes)
    return ax2

# compare_plot must return the twin axes it creates for this to work.
sax1 = compare_plot(axes[0][0], [d0, d1])
sax2 = compare_plot(axes[0][1], [d0, d2])
sax3 = compare_plot(axes[1][0], [d1, d0])
sax4 = compare_plot(axes[1][1], [d3, d2])

# Link every other twin's y-axis to the first twin. Axes.sharey replaces
# get_shared_y_axes().join(), which was deprecated in Matplotlib 3.6 and
# no longer exists in current releases.
for sax in [sax2, sax3, sax4]:
    sax.sharey(sax1)
# Recompute the limits now that the twins are linked.
sax1.autoscale()
# The left-column twins sit between the two columns; hide their
# right-hand tick labels so only the outer column shows the scale.
for sax in [sax1, sax3]:
    sax.yaxis.set_tick_params(labelright=False)

enter image description here

like image 193
ImportanceOfBeingErnest Avatar answered Mar 08 '26 04:03

ImportanceOfBeingErnest



Donate For Us

If you love us, you can donate via PayPal or buy us a coffee so we can maintain and grow the site. Thank you!