 

Matplotlib tight_layout causing RuntimeError

I have come across a problem when using plt.tight_layout() to try to tidy up a matplotlib graph with multiple subplots.

I have created 6 subplots as an example and would like to tidy up their overlapping text with tight_layout(); however, I get the following RuntimeError:

Traceback (most recent call last):
  File ".\test.py", line 37, in <module>
    fig.tight_layout()
  File "C:\Python34\lib\site-packages\matplotlib\figure.py", line 1606, in tight_layout
    rect=rect)
  File "C:\Python34\lib\site-packages\matplotlib\tight_layout.py", line 334, in get_tight_layout_figure
    raise RuntimeError("")
RuntimeError

My code is given here (I am using Python 3.4).

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 3*np.pi, 1000)

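fig = plt.figure()

# Top row: one subplot on a 3x1 grid
ax1 = fig.add_subplot(3, 1, 1)

# Middle row: two subplots on a 3x2 grid
ax2 = fig.add_subplot(3, 2, 3)
ax3 = fig.add_subplot(3, 2, 4)

# Bottom row: three subplots on a 3x3 grid
ax4 = fig.add_subplot(3, 3, 7)
ax5 = fig.add_subplot(3, 3, 8)
ax6 = fig.add_subplot(3, 3, 9)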

for ax in [ax1, ax2, ax3, ax4, ax5, ax6]:
    ax.plot(x, np.sin(x))

fig.tight_layout()

plt.show()

I originally suspected that the problem might come from having subplots of different sizes; however, the tight layout guide seems to suggest that this should not be a problem. Any help or advice would be appreciated.

Ffisegydd asked Mar 29 '14 at 17:03



1 Answer

That is definitely not a helpful error message, although there's a hint in the if clause that leads to the exception. If you use IPython, you'll get some additional context in the traceback. Here's what I saw when I tried to run your code:

    332         div_col, mod_col = divmod(max_ncols, cols)
    333         if (mod_row != 0) or (mod_col != 0):
--> 334             raise RuntimeError("")
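To make the failing condition concrete, here is a minimal pure-Python sketch of the divisibility test in the snippet above. It is not matplotlib's own code; the function names `grids_compatible` and `common_grid` are hypothetical, and the sketch only mirrors the `divmod` logic shown: each subplot's grid rows and columns must divide evenly into the largest row and column counts.

```python
from functools import reduce
from math import gcd


def grids_compatible(shapes):
    """Return True if every (rows, cols) grid divides evenly into the largest one.

    Hypothetical helper that mirrors the divmod check; NOT matplotlib's code.
    """
    max_nrows = max(rows for rows, _ in shapes)
    max_ncols = max(cols for _, cols in shapes)
    return all(max_nrows % rows == 0 and max_ncols % cols == 0
               for rows, cols in shapes)


def _lcm(a, b):
    """Least common multiple of two positive integers."""
    return a * b // gcd(a, b)


def common_grid(shapes):
    """Smallest (rows, cols) grid into which every given grid divides evenly."""
    rows = reduce(_lcm, (r for r, _ in shapes))
    cols = reduce(_lcm, (c for _, c in shapes))
    return rows, cols


# The question mixes 3x1, 3x2 and 3x3 grids: max_ncols is 3, and 3 % 2 != 0.
question_shapes = [(3, 1), (3, 2), (3, 3)]
print(grids_compatible(question_shapes))  # False
print(common_grid(question_shapes))       # (3, 6)
print(grids_compatible([(3, 6)]))         # True
```

The least common multiple of the column counts 1, 2 and 3 is 6, which is where a 3x6 grid comes from.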

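Although you can use tight_layout with subplots of different sizes, they have to be laid out on a common regular grid: as the check above shows, the number of rows and columns in each subplot's grid must divide evenly into the largest row and column counts. Your subplots use 3x1, 3x2, and 3x3 grids, and 3 columns cannot be split evenly by 2. If you look closely at the documentation, the example most closely related to what you're trying to do actually uses the plt.subplot2grid function to set up the plot.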

So, to get exactly what you want, you'll have to lay it out on a 3x6 grid, since 6 columns can be divided evenly among 1, 2, or 3 subplots:

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig = plt.figure()

# Top row
ax1 = plt.subplot2grid((3, 6), (0, 0), colspan=6)

# Middle row
ax2 = plt.subplot2grid((3, 6), (1, 0), colspan=3)
ax3 = plt.subplot2grid((3, 6), (1, 3), colspan=3)

# Bottom row
ax4 = plt.subplot2grid((3, 6), (2, 0), colspan=2)
ax5 = plt.subplot2grid((3, 6), (2, 2), colspan=2)
ax6 = plt.subplot2grid((3, 6), (2, 4), colspan=2)

# Plot a sine wave
for ax in [ax1, ax2, ax3, ax4, ax5, ax6]:
    ax.plot(x, np.sin(x))

# Make the grid nice
fig.tight_layout()

plt.show()


The first argument gives the grid dimensions, the second gives the (row, column) position of the subplot's top-left cell, and the rowspan and colspan arguments say how many cells of the grid each subplot should extend over.
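The same arithmetic can be done mechanically. As a sketch, the hypothetical helper `to_subplot2grid` below converts an `add_subplot(nrows, ncols, index)` spec into the `(loc, rowspan, colspan)` values for a finer common grid, assuming the common grid's dimensions are multiples of `nrows` and `ncols`. Applied to the six specs from the question, it reproduces the hand-written subplot2grid calls in the answer.

```python
def to_subplot2grid(nrows, ncols, index, grid_shape):
    """Map add_subplot(nrows, ncols, index) onto a finer grid_shape.

    Hypothetical helper. `index` is 1-based and runs row by row, as in
    add_subplot. Returns ((row, col), rowspan, colspan).
    """
    grid_rows, grid_cols = grid_shape
    if grid_rows % nrows or grid_cols % ncols:
        raise ValueError("grid_shape must be a multiple of (nrows, ncols)")
    rowspan = grid_rows // nrows  # fine-grid rows covered by one coarse row
    colspan = grid_cols // ncols  # fine-grid columns covered by one coarse column
    row, col = divmod(index - 1, ncols)  # coarse (row, col) of this subplot
    return (row * rowspan, col * colspan), rowspan, colspan


# The add_subplot specs used in the question
specs = [(3, 1, 1), (3, 2, 3), (3, 2, 4), (3, 3, 7), (3, 3, 8), (3, 3, 9)]
for spec in specs:
    print(spec, "->", to_subplot2grid(*spec, grid_shape=(3, 6)))
```

Each result can then be passed on as `plt.subplot2grid((3, 6), loc, rowspan=rowspan, colspan=colspan)`.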

mwaskom answered Oct 21 '22 at 15:10
