
How to sharex when using subplot2grid

I'm a MATLAB user who recently switched to Python. I've picked up most Python skills on my own, but with plotting I have hit a wall and need some help.

This is what I'm trying to do...

I need to make a figure that consists of 3 subplots with the following properties:

  • subplot layout is 311, 312, 313
  • the height of 312 and 313 is approximately half that of 311
  • all subplots share common X axis
  • the space between the subplots is 0 (they touch each other at X axis)

I know how to achieve each of these on its own, just not all of them in a single figure. That is the problem I'm facing now.

For example, this is my ideal subplot layout:

    import numpy as np
    import matplotlib.pyplot as plt

    t = np.arange(0.0, 2.0, 0.01)

    s1 = np.sin(2*np.pi*t)
    s2 = np.exp(-t)
    s3 = s1*s2

    fig = plt.figure()
    ax1 = plt.subplot2grid((4,3), (0,0), colspan=3, rowspan=2)
    ax2 = plt.subplot2grid((4,3), (2,0), colspan=3)
    ax3 = plt.subplot2grid((4,3), (3,0), colspan=3)

    ax1.plot(t,s1)
    ax2.plot(t[:150],s2[:150])
    ax3.plot(t[30:],s3[30:])

    plt.tight_layout()

    plt.show()

Notice how the x axes of the different subplots are misaligned. I do not know how to align them in this figure, but if I do something like this:

    import numpy as np
    import matplotlib.pyplot as plt

    t = np.arange(0.0, 2.0, 0.01)

    s1 = np.sin(2*np.pi*t)
    s2 = np.exp(-t)
    s3 = s1*s2

    fig2, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, sharex=True)

    ax1.plot(t,s1)
    ax2.plot(t[:150],s2[:150])
    ax3.plot(t[30:],s3[30:])

    plt.tight_layout()

    plt.show()

Now the x axis is aligned between the subplots, but all subplots are the same size (which is not what I want).

Furthermore, I would like the subplots to touch each other along the x axis, like this:

    import numpy as np
    import matplotlib.pyplot as plt

    t = np.arange(0.0, 2.0, 0.01)

    s1 = np.sin(2*np.pi*t)
    s2 = np.exp(-t)
    s3 = s1*s2

    fig1 = plt.figure()
    plt.subplots_adjust(hspace=0)

    ax1 = plt.subplot(311)
    ax2 = plt.subplot(312, sharex=ax1)
    ax3 = plt.subplot(313, sharex=ax1)

    ax1.plot(t,s1)
    ax2.plot(t[:150],s2[:150])
    ax3.plot(t[30:],s3[30:])

    xticklabels = ax1.get_xticklabels()+ax2.get_xticklabels()
    plt.setp(xticklabels, visible=False)

    plt.show()

So to rephrase my question:

I would like to use

    plt.subplot2grid(..., colspan=3, rowspan=2)
    plt.subplots(..., sharex=True)
    plt.subplots_adjust(hspace=0)

and

    plt.tight_layout()

together in the same figure. How can I do that?

asked Feb 09 '14 16:02 by Boris L.



People also ask

What does sharex=False mean?

If sharex is set to False or 'none', each subplot's x-axis is independent. If it is set to 'row', each subplot row shares an x-axis; if it is set to 'col', each subplot column shares an x-axis. The sharey parameter works the same way for the y-axis.
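A minimal sketch of the 'col' option with plt.subplots. The 2×2 layout, the data, and the Agg backend are illustrative assumptions; the sharing check uses Axes.get_shared_x_axes().joined(), which reports whether two axes are in the same sharing group.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(0.0, 2.0, 0.01)

# sharex="col": the two axes in each column share an x-axis,
# but the left and right columns are independent of each other.
fig, axs = plt.subplots(nrows=2, ncols=2, sharex="col")
axs[0, 0].plot(t, np.sin(2 * np.pi * t))
axs[1, 0].plot(t, np.exp(-t))
axs[0, 1].plot(10 * t, np.cos(2 * np.pi * t))  # right column uses a 10x wider x range
axs[1, 1].plot(10 * t, np.exp(-t))

grouper = axs[0, 0].get_shared_x_axes()
print(grouper.joined(axs[0, 0], axs[1, 0]))  # True: same column
print(grouper.joined(axs[0, 0], axs[0, 1]))  # False: same row, different column
```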

How do you share the y axis?

You can share the x or y axis limits for one axis with another by passing an Axes instance as a sharex or sharey keyword argument.
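A small sketch of passing an Axes instance as sharex when adding a subplot; the 2×1 layout and the limits are arbitrary choices for illustration, and sharey=ax1 would behave the same way for the y-axis.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_subplot(2, 1, 1)
# Passing the existing Axes instance makes ax2 share ax1's x-axis limits.
ax2 = fig.add_subplot(2, 1, 2, sharex=ax1)

# Changing the x limits on one axes is reflected on the other.
ax1.set_xlim(0, 5)
print(ax2.get_xlim())
```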

What are sharex and sharey in Matplotlib?

They control sharing of properties among x (sharex) or y (sharey) axes: True or 'all': the x- or y-axis is shared among all subplots. False or 'none': each subplot's x- or y-axis is independent. 'row': each subplot row shares an x- or y-axis. 'col': each subplot column shares an x- or y-axis.

How do I reduce the space between subplots in Matplotlib?

To remove the space between subplots in Matplotlib, you can create a GridSpec (for example GridSpec(3, 3)) with hspace=0 and/or wspace=0 and add the axes to its cells as subplots.
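A sketch of the GridSpec approach applied to the layout in the question, assuming a Matplotlib version whose GridSpec accepts the figure, height_ratios and hspace keywords (recent releases do). It uses a 3×1 grid with height_ratios=[2, 1, 1] instead of the GridSpec(3, 3) mentioned above, so the top axes is twice as tall as the other two and the rows touch.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np

t = np.arange(0.0, 2.0, 0.01)
s1 = np.sin(2 * np.pi * t)
s2 = np.exp(-t)
s3 = s1 * s2

fig = plt.figure()
# Three rows; the first is twice as tall as the others, and hspace=0 removes the gaps.
gs = GridSpec(3, 1, figure=fig, height_ratios=[2, 1, 1], hspace=0)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1], sharex=ax1)
ax3 = fig.add_subplot(gs[2], sharex=ax1)

ax1.plot(t, s1)
ax2.plot(t[:150], s2[:150])
ax3.plot(t[30:], s3[30:])

# Hide the x tick labels of the upper two axes; only the bottom axes keeps them.
for ax in (ax1, ax2):
    ax.tick_params(labelbottom=False)
```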


2 Answers

Just specify sharex=ax1 when creating your second and third subplots.

    import numpy as np
    import matplotlib.pyplot as plt

    t = np.arange(0.0, 2.0, 0.01)

    s1 = np.sin(2*np.pi*t)
    s2 = np.exp(-t)
    s3 = s1*s2

    fig = plt.figure()
    ax1 = plt.subplot2grid((4,3), (0,0), colspan=3, rowspan=2)
    ax2 = plt.subplot2grid((4,3), (2,0), colspan=3, sharex=ax1)
    ax3 = plt.subplot2grid((4,3), (3,0), colspan=3, sharex=ax1)

    ax1.plot(t,s1)
    ax2.plot(t[:150],s2[:150])
    ax3.plot(t[30:],s3[30:])

    fig.subplots_adjust(hspace=0)

    for ax in [ax1, ax2]:
        plt.setp(ax.get_xticklabels(), visible=False)
        # The y-ticks will overlap with "hspace=0", so we'll hide the bottom tick
        ax.set_yticks(ax.get_yticks()[1:])

    plt.show()


If you still want to use fig.tight_layout(), you'll need to call it before fig.subplots_adjust(hspace=0). The reason is that tight_layout works by automatically calculating parameters for subplots_adjust and then calling it, so if subplots_adjust is called manually first, whatever it set will be overridden by tight_layout.

E.g.

    fig.tight_layout()
    fig.subplots_adjust(hspace=0)
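To check the call order, the sketch below repeats the subplot2grid layout from this answer, calls tight_layout() first and subplots_adjust(hspace=0) second, and then inspects the resulting figure parameters and axes positions. The Agg backend and the print-based inspection are assumptions made so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the check runs without a display
import matplotlib.pyplot as plt
import numpy as np

t = np.arange(0.0, 2.0, 0.01)
s1 = np.sin(2 * np.pi * t)

fig = plt.figure()
ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=2)
ax2 = plt.subplot2grid((4, 3), (2, 0), colspan=3, sharex=ax1)
ax3 = plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
ax1.plot(t, s1)
ax2.plot(t[:150], np.exp(-t[:150]))
ax3.plot(t[30:], (s1 * np.exp(-t))[30:])

# tight_layout() computes margins and hspace and applies them via subplots_adjust ...
fig.tight_layout()
# ... so the manual hspace=0 has to come second, or tight_layout would replace it.
fig.subplots_adjust(hspace=0)

# With hspace=0 the bottom edge of each axes coincides with the top edge of the next.
print(fig.subplotpars.hspace)
print(ax1.get_position().y0, ax2.get_position().y1)
```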
answered Oct 07 '22 18:10 by Joe Kington



A possible solution is to create the axes manually using the add_axes method, as shown here:

    import numpy as np
    import matplotlib.pyplot as plt

    t = np.arange(0.0, 2.0, 0.01)

    s1 = np.sin(2*np.pi*t)
    s2 = np.exp(-t)
    s3 = s1*s2

    # each rect is [left, bottom, width, height] in figure coordinates
    left, width = 0.1, 0.8
    rect1 = [left, 0.5, width, 0.4]
    rect2 = [left, 0.3, width, 0.15]
    rect3 = [left, 0.1, width, 0.15]

    fig = plt.figure()
    ax1 = fig.add_axes(rect1)
    ax2 = fig.add_axes(rect2, sharex=ax1)
    ax3 = fig.add_axes(rect3, sharex=ax1)

    ax1.plot(t,s1)
    ax2.plot(t[:150],s2[:150])
    ax3.plot(t[30:],s3[30:])

    # hide labels
    for label1,label2 in zip(ax1.get_xticklabels(),ax2.get_xticklabels()):
        label1.set_visible(False)
        label2.set_visible(False)

    plt.show()

But this way you cannot use tight_layout, because you explicitly define the position and size of each axes.
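With the rects above the axes do not quite touch: rect1 starts at 0.5 while rect2 ends at 0.3 + 0.15 = 0.45, and the same 0.05 gap separates rect2 and rect3. One way to get touching axes with add_axes is to derive the rects from height ratios. The helper stacked_rects below is hypothetical (not part of Matplotlib), and its default margins are arbitrary assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def stacked_rects(ratios, left=0.1, width=0.8, bottom=0.1, top=0.9):
    """Hypothetical helper: split the band [bottom, top] into touching
    [left, bottom, width, height] rects, listed from top to bottom."""
    total = sum(ratios)
    rects = []
    y = top
    for r in ratios:
        height = (top - bottom) * r / total
        y -= height  # the bottom of this rect is the top of the next one
        rects.append([left, y, width, height])
    return rects


t = np.arange(0.0, 2.0, 0.01)
s1 = np.sin(2 * np.pi * t)
s2 = np.exp(-t)

rects = stacked_rects([2, 1, 1])  # top axes twice as tall as the other two
fig = plt.figure()
ax1 = fig.add_axes(rects[0])
ax2 = fig.add_axes(rects[1], sharex=ax1)
ax3 = fig.add_axes(rects[2], sharex=ax1)

ax1.plot(t, s1)
ax2.plot(t[:150], s2[:150])
ax3.plot(t[30:], (s1 * s2)[30:])

# Only the bottom axes keeps its x tick labels.
for ax in (ax1, ax2):
    ax.tick_params(labelbottom=False)
```

The positions are still fixed by hand here, so tight_layout remains unusable with this approach.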

answered Oct 07 '22 16:10 by Jakob
