Using matplotlib, I'd like to display multiple subplots on a grid that has a different number of columns per row, where each subplot has roughly the same size and the subplots are more or less centered. For a (2, 3, 2) arrangement, each plot in the 2-plot rows should sit midway between two plots of the 3-plot row.
It's a fairly simple matter to create a grid with the 2, 3, 2 pattern using `gridspec`, but the problem is that `gridspec`, unsurprisingly, aligns the plots to a grid, so the plots in the rows with 2 plots are wider than those in the row with 3.
Here's the code to generate that:
```python
from matplotlib import gridspec
from matplotlib import pyplot as plt

fig = plt.figure()

arrangement = (2, 3, 2)
nrows = len(arrangement)

# One outer row per entry in `arrangement`
gs = gridspec.GridSpec(nrows, 1)
for r, ncols in enumerate(arrangement):
    # Split outer row r into `ncols` equally wide columns
    gs_row = gridspec.GridSpecFromSubplotSpec(1, ncols, subplot_spec=gs[r])
    for col in range(ncols):
        fig.add_subplot(gs_row[col])

for i, ax in enumerate(fig.axes):
    ax.text(0.5, 0.5, "Axis: {}".format(i), fontweight='bold',
            va="center", ha="center")
    # Hide ticks and tick labels (current Matplotlib expects booleans, not 'off')
    ax.tick_params(axis='both', bottom=False, top=False, left=False,
                   right=False, labelbottom=False, labelleft=False)

plt.tight_layout()
plt.show()
```
I know that I can set up a bunch of subplots and tweak their arrangement by working out the geometry of it, but I think it could get a bit complicated, so I was hoping that there might be a simpler method available.
I should note that even though I'm using a (2, 3, 2) arrangement as my example, I'd like to do this for arbitrary arrangements, not just this one.
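For comparison, the manual-geometry route mentioned above can be sketched in plain Python. This is not the gridspec approach: it just computes a `[left, bottom, width, height]` rectangle in figure coordinates for every axis, giving all axes the same size and centering each row. `centered_rects` and its `margin`, `hgap` and `vgap` parameters are hypothetical names chosen for illustration; all values are figure fractions and rows are listed top to bottom.

```python
def centered_rects(arrangement, margin=0.05, hgap=0.05, vgap=0.05):
    """Hypothetical helper: one [left, bottom, width, height] rect per axis.

    `arrangement` gives the number of axes per row, top to bottom.
    All values are fractions of the figure size.
    """
    nrows = len(arrangement)
    maxcols = max(arrangement)
    # Same size for every axis, chosen so the widest row fills the usable width
    width = (1 - 2 * margin - (maxcols - 1) * hgap) / maxcols
    height = (1 - 2 * margin - (nrows - 1) * vgap) / nrows
    rects = []
    for r, ncols in enumerate(arrangement):
        row_width = ncols * width + (ncols - 1) * hgap
        left0 = (1 - row_width) / 2  # center the row horizontally
        bottom = 1 - margin - (r + 1) * height - r * vgap
        for c in range(ncols):
            rects.append([left0 + c * (width + hgap), bottom, width, height])
    return rects


rects = centered_rects((2, 3, 2))
```

Each rectangle can then be passed to `fig.add_axes(rect)`. The drawback is that the margins for tick labels and titles have to be tuned by hand, since `tight_layout` only adjusts gridspec-based subplots.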
The usual approach is to find a common denominator for the subplots, i.e. the largest grid cell that the desired layout can be composed of, and let each subplot span several of those cells so that the desired layout is achieved.
Here you have 3 rows and 6 columns, and each subplot spans 1 row and 2 columns. The only difference is that the subplots in the first and third rows span columns 1-2 and 3-4 (`gs[0, 1:3]` and `gs[0, 3:5]`), while those in the second row span columns 0-1, 2-3 and 4-5.
```python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# 3 rows x 6 columns; every subplot spans 2 columns
gs = gridspec.GridSpec(3, 6)

# Rows with 2 plots are shifted right by one column to center them
ax1a = plt.subplot(gs[0, 1:3])
ax1b = plt.subplot(gs[0, 3:5])

ax2a = plt.subplot(gs[1, :2])
ax2b = plt.subplot(gs[1, 2:4])
ax2c = plt.subplot(gs[1, 4:])

ax3a = plt.subplot(gs[2, 1:3])
ax3b = plt.subplot(gs[2, 3:5])

for i, ax in enumerate(plt.gcf().axes):
    ax.text(0.5, 0.5, "Axis: {}".format(i), fontweight='bold',
            va="center", ha="center")
    # Hide ticks and tick labels (booleans, not 'off')
    ax.tick_params(axis='both', bottom=False, top=False, left=False,
                   right=False, labelbottom=False, labelleft=False)

plt.tight_layout()
plt.show()
```
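The same idea extends to arbitrary arrangements: make the grid `2 * max(arrangement)` columns wide, let every subplot span 2 columns, and shift a row with `n` subplots right by `max(arrangement) - n` columns. A minimal sketch, where `centered_subplots` is a hypothetical helper name rather than a Matplotlib function:

```python
import matplotlib.pyplot as plt
from matplotlib import gridspec


def centered_subplots(fig, arrangement):
    """Hypothetical helper: add equally sized, centered axes to `fig`.

    `arrangement` gives the number of axes per row, top to bottom.
    Returns a list of rows, each a list of Axes.
    """
    nrows = len(arrangement)
    maxcols = max(arrangement)
    # Each subplot spans 2 grid columns, so half-plot offsets are possible
    gs = gridspec.GridSpec(nrows, 2 * maxcols, figure=fig)
    rows = []
    for r, ncols in enumerate(arrangement):
        # Half of the 2 * (maxcols - ncols) unused columns go on the left
        offset = maxcols - ncols
        row_axes = []
        for c in range(ncols):
            start = offset + 2 * c
            row_axes.append(fig.add_subplot(gs[r, start:start + 2]))
        rows.append(row_axes)
    return rows


fig = plt.figure()
rows = centered_subplots(fig, (2, 3, 2))
for i, ax in enumerate(fig.axes):
    ax.text(0.5, 0.5, "Axis: {}".format(i), fontweight="bold",
            va="center", ha="center")
    ax.tick_params(bottom=False, left=False,
                   labelbottom=False, labelleft=False)
```

For (2, 3, 2) this reproduces exactly the slices used above; call `plt.tight_layout()` and `plt.show()` as before to display it. Because offsets are counted in half-plot columns, rows whose counts differ from the widest row by an odd number are still centered exactly.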