
Matplotlib: align origin of right axis with specific left axis value

When plotting several y axes in Matplotlib, is there a way to align the origin (and/or some y tick labels) of the right axis with a specific value on the left axis?

Here is my problem: I would like to plot two sets of data as well as their difference (basically, I am trying to reproduce this kind of graph).

I can reproduce it, but I have to manually adjust the ylim of the right axis so that the origin is aligned with the value I want from the left axis.

Below is a simplified version of the code I use. As you can see, I have to adjust the scale of the right axis by hand so that its origin and the square marker line up with the corresponding values on the left axis.

import numpy as np
import matplotlib.pyplot as plt

grp1 = np.array([1.202, 1.477, 1.223, 1.284, 1.701, 1.724, 1.099,
                1.242, 1.099, 1.217, 1.291, 1.305, 1.333, 1.246])
grp2 = np.array([1.802, 2.399, 2.559, 2.286, 2.460, 2.511, 2.296,
                1.975])

fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.17, 0.13, 0.6, 0.7])

# remove the top, right and bottom spines and turn ticks off where there is no spine
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('left')
# point the ticks outward
ax.tick_params(axis='both', direction='out', width=3, length=7,
        labelsize=24, pad=8)
ax.spines['left'].set_linewidth(3)

# plot groups vs random numbers to create dot plot
ax.plot(np.random.normal(1, 0.05, grp2.size), grp2, 'ok', markersize=10)
ax.plot(np.random.normal(2, 0.05, grp1.size), grp1, 'ok', markersize=10)
ax.errorbar(1, np.mean(grp2), fmt='_r', markersize=50,
        markeredgewidth=3)
ax.errorbar(2, np.mean(grp1), fmt='_r', markersize=50,
        markeredgewidth=3)


ax.set_xlim((0.5, 3.5))
ax.set_ylim((0, 2.7))

# create right axis
ax2 = fig.add_axes(ax.get_position(), sharex=ax, frameon=False)
ax2.spines['left'].set_color('none')
ax2.spines['top'].set_color('none')
ax2.spines['bottom'].set_color('none')
ax2.xaxis.set_ticks_position('none')
ax2.yaxis.set_ticks_position('right')
# point the ticks outward
ax2.tick_params(axis='both', direction='out', width=3, length=7,
        labelsize=24, pad=8)
ax2.spines['right'].set_linewidth(3)
ax2.set_xticks([1, 2, 3])
ax2.set_xticklabels(('gr2', 'gr1', 'D'))
ax2.hlines(0, 0.5, 3.5, linestyle='dotted')
#ax2.hlines((np.mean(adult)-np.mean(nrvm)), 0, 3.5, linestyle='dotted')

ax2.plot(3, (np.mean(grp1)-np.mean(grp2)), 'sk', markersize=12)

# manual adjustment so the origin is aligned with the mean of group 2 on the left axis
ax2.set_ylim((-2.3, 0.42))
ax2.set_xlim((0.5, 3.5))

plt.show()
asked Dec 21 '22 08:12 by gcalmettes



2 Answers

You can write a small function that shifts the y limits of ax2 so that a chosen value on ax2 lines up with a chosen value on ax1:

def align_yaxis(ax1, v1, ax2, v2):
    """adjust ax2 ylimit so that v2 in ax2 is aligned to v1 in ax1"""
    # display (pixel) heights of v1 on ax1 and of v2 on ax2
    _, y1 = ax1.transData.transform((0, v1))
    _, y2 = ax2.transData.transform((0, v2))
    # convert the pixel offset between them into ax2 data units
    inv = ax2.transData.inverted()
    _, dy = inv.transform((0, 0)) - inv.transform((0, y1-y2))
    # shift both ax2 limits by that amount
    miny, maxy = ax2.get_ylim()
    ax2.set_ylim(miny+dy, maxy+dy)

With align_yaxis() you can then align the axes directly:

#...... your code

# adjustment so the origin is aligned with the mean of group 2 on the left axis
ax2.set_ylim((0, 2.7))
align_yaxis(ax, np.mean(grp2), ax2, 0)
plt.show()
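
For linear axes that occupy the same rectangle (as ax and ax2 do here), the transform round trip above should amount to matching fractions of the axis height: v1 is drawn at f = (v1 - lo1) / (hi1 - lo1) of the way up the left axis, so the right limits must put v2 at that same fraction. Below is a minimal sketch of that arithmetic in plain Python, assuming linear scales and a right-axis span you pick yourself; `aligned_ylim` is a hypothetical helper, not part of Matplotlib.

```python
def aligned_ylim(left_ylim, v1, right_span, v2):
    """Return (lo, hi) for the right axis so that v2 is drawn at the same
    height as v1 on the left axis.

    Assumes (hypothetically) linear y scales and that both axes occupy the
    same rectangle in the figure.
    """
    lo1, hi1 = left_ylim
    # Fraction of the way up the left axis at which v1 is drawn.
    frac = (v1 - lo1) / (hi1 - lo1)
    # Put v2 at the same fraction of a right axis with the requested span.
    lo2 = v2 - frac * right_span
    return lo2, lo2 + right_span
```

With the question's numbers (left limits (0, 2.7), mean of grp2 = 2.286, right span 2.7, v2 = 0), `aligned_ylim((0, 2.7), np.mean(grp2), 2.7, 0)` gives roughly (-2.286, 0.414), close to the hand-tuned (-2.3, 0.42); the result can be passed straight to `ax2.set_ylim()`.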
answered Dec 26 '22 10:12 by HYRY



The answer above works, but it can sometimes push data outside the plotted range. This is covered more thoroughly in the second answer to this question:

Matplotlib axis with two scales shared origin

Alternatively, here is a quick hack:

def align_yaxis(ax1, v1, ax2, v2, y2min, y2max):
    """Adjust ax2 ylimit so that v2 in ax2 is aligned to v1 in ax1.

    y2max is the maximum value in your secondary plot. I haven't had a
    problem with minimum values being cut, so y2min is not used. This
    approach doesn't necessarily give axis limits at nice round numbers,
    but it does optimise plot space. Because the limits are scaled about
    zero, the alignment is only preserved when v2 is 0.
    """
    # shift ax2 so that v2 lines up with v1 (same as the answer above)
    _, y1 = ax1.transData.transform((0, v1))
    _, y2 = ax2.transData.transform((0, v2))
    inv = ax2.transData.inverted()
    _, dy = inv.transform((0, 0)) - inv.transform((0, y1-y2))
    miny, maxy = ax2.get_ylim()
    # zoom out in 5% steps until y2max fits below the top limit
    scale = 1
    if maxy + dy > 0:  # otherwise the loop would never terminate
        while scale*(maxy+dy) < y2max:
            scale += 0.05
    ax2.set_ylim(scale*(miny+dy), scale*(maxy+dy))
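
Multiplying both limits by one factor keeps the point 0 at a fixed fraction of the axis height, which is why the hack works when v2 is 0. A minimal sketch of a variant, assuming linear axes and that v2 lies strictly between the current limits, instead zooms out about v2 itself and also honours y2min. `expand_keeping_alignment` is a hypothetical helper, meant to be applied after the plain align_yaxis from the first answer.

```python
def expand_keeping_alignment(ylim, v2, y2min, y2max):
    """Widen ylim just enough to contain [y2min, y2max] while keeping v2
    at the same fraction of the axis height.

    Hypothetical helper; assumes a linear axis and lo < v2 < hi.
    """
    lo, hi = ylim
    below = v2 - lo  # distance from v2 down to the bottom limit
    above = hi - v2  # distance from v2 up to the top limit
    # Smallest zoom-out factor (never below 1) that makes room for both extremes.
    k = max(1.0, (v2 - y2min) / below, (y2max - v2) / above)
    # Scaling both distances by k leaves below / (below + above) unchanged,
    # so v2 stays at the same height as before.
    return v2 - k * below, v2 + k * above
```

Typical use would be `ax2.set_ylim(*expand_keeping_alignment(ax2.get_ylim(), 0, y2min, y2max))`, where y2min and y2max are the extremes of the data drawn on ax2. Computing k directly also avoids the 5% stepping, so the limits are only as wide as needed.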
answered Dec 26 '22 12:12 by Paul Brogan
