How to draw a box across multiple axes in matplotlib using one axes' position as reference

I would like to draw a box across multiple axes, using one axes' coordinates as reference. The simple code I have, which does not draw the box yet, is:

import matplotlib.pyplot as plt
import numpy as np

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=False, sharey=False, figsize=(15,9))

x = 2 * np.pi * np.arange(1000) / 1000 
y1 = np.sin(x)
y2 = np.cos(x)

ax1.plot(x,y1)
ax2.plot(x,y2)
plt.show()

This generates the following figure:

[Image: the two subplots, sin(x) in ax1 and cos(x) in ax2]

What I would like to have is the following figure, using x coordinates from ax2 to specify the position:

[Image: the desired figure, with a box drawn across both subplots]

asked Apr 23 '19 at 15:04 by rmelo

2 Answers

The question is what purpose the rectangle should serve. If it is simply a rectangle bound to ax2 but extending up to the upper edge of ax1, it can be created like this:

import matplotlib.pyplot as plt
import numpy as np

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=False, sharey=False, figsize=(15,9))

x = 2 * np.pi * np.arange(1000) / 1000 
y1 = np.sin(x)
y2 = np.cos(x)

ax1.plot(x,y1)
ax2.plot(x,y2)

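# In ax2's x-axis transform, x is in data units and y in axes fractions:
# ax2 spans y = 0..1, the vertical gap between the rows is hspace, and ax1
# spans 1 + hspace .. 2 + hspace, so this height reaches the top edge of ax1.
# clip_on=False lets the patch extend outside ax2.
rect = plt.Rectangle((1,0), width=1, height=2+fig.subplotpars.hspace,
                     transform=ax2.get_xaxis_transform(), clip_on=False,
                     edgecolor="k", facecolor="none", linewidth=3)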
ax2.add_patch(rect)

plt.show()

[Image: resulting figure]
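The snippet above positions the rectangle with ax2.get_xaxis_transform(). As a minimal sketch of what that transform does (the variable names are only for illustration), it should be the same blend you get from matplotlib.transforms.blended_transform_factory(ax.transData, ax.transAxes): x is read as data, y as a fraction of the axes height, with y=0 at the bottom edge and y=1 at the top edge.

```python
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 9))
x = 2 * np.pi * np.arange(1000) / 1000
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x))
fig.canvas.draw()  # settle limits and layout before querying pixel positions

# x in data units, y in axes fractions (0 = bottom edge, 1 = top edge of ax2)
xaxis_trans = ax2.get_xaxis_transform()
# The same blend built explicitly from its two halves
manual_trans = mtransforms.blended_transform_factory(ax2.transData, ax2.transAxes)

# Data x = 1 at the bottom and at the top edge of ax2, mapped to display pixels
pts = np.array([[1.0, 0.0], [1.0, 1.0]])
via_method = xaxis_trans.transform(pts)
via_factory = manual_trans.transform(pts)
print(via_method)
print(via_factory)
```

Because the y part ignores the data limits, a y value above 1 (such as the 2 + hspace height used above) simply continues past the top edge of ax2, measured in multiples of its height.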

But that rectangle will of course stay where it is, even if the limits of ax1 change. Is that desired?
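The height 2 + hspace relies on the layout that plt.subplots produces by default: two rows of equal height separated by hspace (a fraction of the average axes height). With unequal rows (for example gridspec_kw={"height_ratios": [2, 1]}) or axes moved with set_position(), that formula no longer matches. A sketch under that assumption that derives the height from the actual axes positions instead; span_height is a hypothetical helper, not a matplotlib function.

```python
import matplotlib.pyplot as plt
import numpy as np

def span_height(ax_top, ax_bottom):
    """Hypothetical helper: distance from the bottom edge of ax_bottom to the
    top edge of ax_top, in units of ax_bottom's height (its axes fraction)."""
    top = ax_top.get_position()        # Bbox in figure-fraction coordinates
    bottom = ax_bottom.get_position()
    return (top.y1 - bottom.y0) / bottom.height

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 9))
x = 2 * np.pi * np.arange(1000) / 1000
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x))

# Same rectangle as above, but with the height measured from the layout.
height = span_height(ax1, ax2)
rect = plt.Rectangle((1, 0), width=1, height=height,
                     transform=ax2.get_xaxis_transform(), clip_on=False,
                     edgecolor="k", facecolor="none", linewidth=3)
ax2.add_patch(rect)
print(height)
```

With the default two equal rows this reproduces 2 + fig.subplotpars.hspace. Since it reads the positions at the moment it is called, it should be called after any layout change.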

So maybe a more interesting solution is one where the rectangle follows the coordinates in both axes. The following only works with matplotlib 3.1 or newer, which at the time of writing was only available as a prerelease (pip install --pre --upgrade matplotlib).

import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import numpy as np

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=False, sharey=False, figsize=(15,9))

x = 2 * np.pi * np.arange(1000) / 1000 
y1 = np.sin(x)
y2 = np.cos(x)

ax1.plot(x,y1)
ax2.plot(x,y2)

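def rectspan(x1, x2, ax1, ax2, **kwargs):
    # Upper part of the box, inside ax1: from ax1's bottom edge (y=0) up to
    # its top edge (y=1); x in data units, y in axes fractions.
    ax1.plot([x1, x1, x2, x2],[0,1,1,0],
             transform=ax1.get_xaxis_transform(), **kwargs)
    # Lower part of the box, inside ax2: from its top edge down to its bottom edge.
    ax2.plot([x1, x1, x2, x2],[1,0,0,1],
             transform=ax2.get_xaxis_transform(), **kwargs)
    # Bridge the gap between the axes: join the top edge of ax2 (y=1)
    # to the bottom edge of ax1 (y=0) at both x positions.
    for x in (x1, x2):
        p = ConnectionPatch((x,1), (x,0),
                            coordsA=ax2.get_xaxis_transform(),
                            coordsB=ax1.get_xaxis_transform(), **kwargs)
        ax1.add_artist(p)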

rectspan(1, 2, ax1, ax2, color="k", linewidth=3)
plt.show()

[Image: resulting figure]
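To see what "follows the coordinates in both axes" means in practice, here is a sketch that repeats the rectspan construction and then zooms only ax1. Renaming the parameters to ax_top/ax_bottom and returning the artists are my additions for illustration. It needs matplotlib 3.1 or newer, like the code above.

```python
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import numpy as np

def rectspan(x1, x2, ax_top, ax_bottom, **kwargs):
    """Box from x1 to x2 spanning ax_top and ax_bottom; returns its artists."""
    # Upper part, inside ax_top: x in data units, y in axes fractions.
    top, = ax_top.plot([x1, x1, x2, x2], [0, 1, 1, 0],
                       transform=ax_top.get_xaxis_transform(), **kwargs)
    # Lower part, inside ax_bottom.
    bottom, = ax_bottom.plot([x1, x1, x2, x2], [1, 0, 0, 1],
                             transform=ax_bottom.get_xaxis_transform(), **kwargs)
    # Bridge the gap: top edge of ax_bottom to bottom edge of ax_top.
    connectors = []
    for xv in (x1, x2):
        p = ConnectionPatch((xv, 1), (xv, 0),
                            coordsA=ax_bottom.get_xaxis_transform(),
                            coordsB=ax_top.get_xaxis_transform(), **kwargs)
        ax_top.add_artist(p)
        connectors.append(p)
    return top, bottom, connectors

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 9))
x = 2 * np.pi * np.arange(1000) / 1000
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x))
top, bottom, connectors = rectspan(1, 2, ax1, ax2, color="k", linewidth=3)

# Zoom only ax1; ax2 keeps its autoscaled x-limits.
ax1.set_xlim(0, 4)
fig.canvas.draw()

# Pixel x of data x = 1 as seen by the box's upper part and by each axes.
left_px_top = top.get_transform().transform((1, 0))[0]
x1_px_ax1 = ax1.transData.transform((1, 0))[0]
x1_px_ax2 = ax2.transData.transform((1, 0))[0]
print(left_px_top, x1_px_ax1, x1_px_ax2)
```

After the zoom, x=1 maps to different pixel positions in ax1 and ax2, so the two ConnectionPatch sides are drawn slanted rather than vertical; whether that is acceptable depends on what the box is meant to show.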

answered Nov 12 '22 at 18:11 by ImportanceOfBeingErnest


There is definitely a simpler way to do it using a Rectangle patch, but this is a workaround for the time being. The idea is to draw four lines: two horizontal ones restricted to ax1 and ax2 respectively, and two vertical ones that span both ax1 and ax2. For the latter two, you use a ConnectionPatch covering both axes. To get the upper and lower y-values for the horizontal and vertical lines, you use the get_ylim() function. The idea to plot the vertical lines came from this official example and this answer by ImportanceOfBeingErnest.

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import ConnectionPatch

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=False, sharey=False, figsize=(15,9))

x = 2 * np.pi * np.arange(1000) / 1000 
y1 = np.sin(x)
y2 = np.cos(x)

ax1.plot(x,y1)
ax2.plot(x,y2)

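# Current y-limits (data units) of both axes; max(y_up) is the top of ax1
# and min(y_down) the bottom of ax2.
y_up, y_down = ax1.get_ylim(), ax2.get_ylim()

# Horizontal sides: top edge in ax1, bottom edge in ax2.
ax1.hlines(max(y_up), 1, 2, linewidth=4)
ax2.hlines(min(y_down), 1, 2, linewidth=4)

# Vertical sides: each ConnectionPatch joins a point in ax2's data
# coordinates (xyA) to a point in ax1's data coordinates (xyB).
line1 = ConnectionPatch(xyA=[1,min(y_down)], xyB=[1,max(y_up)], coordsA="data", coordsB="data",
                        axesA=ax2, axesB=ax1, color="k", lw=4)
line2 = ConnectionPatch(xyA=[2,min(y_down)], xyB=[2,max(y_up)], coordsA="data", coordsB="data",
                        axesA=ax2, axesB=ax1, color="k", lw=4)

ax2.add_artist(line1)
ax2.add_artist(line2)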

plt.show()

[Image: resulting figure]
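One side effect worth checking with the approach above: hlines() updates the axes' data limits and requests a new autoscale, so the y-limits read with get_ylim() should get expanded again by the default margins, leaving the box slightly inside the axes rather than on their edges. A sketch that avoids this, assuming that is the goal, by freezing the limits first with set_ylim() (which switches y-autoscaling off); the rest follows the answer's construction, with the two ConnectionPatch calls folded into a loop.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import ConnectionPatch

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 9))
x = 2 * np.pi * np.arange(1000) / 1000
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x))

# Freeze the current y-limits: set_ylim() turns y-autoscaling off,
# so the hlines added below cannot widen the limits again.
ax1.set_ylim(ax1.get_ylim())
ax2.set_ylim(ax2.get_ylim())
y_top = max(ax1.get_ylim())      # top edge of ax1, in ax1 data units
y_bottom = min(ax2.get_ylim())   # bottom edge of ax2, in ax2 data units

# Horizontal sides
ax1.hlines(y_top, 1, 2, linewidth=4)
ax2.hlines(y_bottom, 1, 2, linewidth=4)

# Vertical sides spanning both axes, one per x position
for xv in (1, 2):
    line = ConnectionPatch(xyA=(xv, y_bottom), xyB=(xv, y_top),
                           coordsA="data", coordsB="data",
                           axesA=ax2, axesB=ax1, color="k", lw=4)
    ax2.add_artist(line)

fig.canvas.draw()  # limits stay at the frozen values after drawing
```

The trade-off is that later calls to plot() on either axes will no longer rescale y automatically; call ax.autoscale(axis="y") or set_ylim() again if the data change.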

answered Nov 12 '22 at 18:11 by Sheldore