 

Arrange matplotlib subplots in skewed grid

Using matplotlib, I'd like to display multiple subplots on a grid that has a different number of columns per row, where all subplots are roughly the same size and each row is more or less centered, like this:

[Image: grid of axes in pattern (2, 3, 2)]

It's fairly simple to create a grid with the (2, 3, 2) pattern using gridspec, but gridspec, unsurprisingly, aligns the subplots to a grid: every row is stretched to the full width, so the plots in the rows with only 2 plots are wider:

[Image: grid aligned with gridspec]

Here's the code to generate that:

from matplotlib import gridspec
from matplotlib import pyplot as plt

fig = plt.figure()

arrangement = (2, 3, 2)
nrows = len(arrangement)

# One outer row per entry, each split into its own number of columns
gs = gridspec.GridSpec(nrows, 1, figure=fig)
for r, ncols in enumerate(arrangement):
    gs_row = gridspec.GridSpecFromSubplotSpec(1, ncols, subplot_spec=gs[r])
    for col in range(ncols):
        fig.add_subplot(gs_row[col])

# Label each axes and hide its ticks and tick labels
for i, ax in enumerate(fig.axes):
    ax.text(0.5, 0.5, "Axis: {}".format(i), fontweight='bold',
            va="center", ha="center")
    ax.tick_params(axis='both', bottom=False, top=False, left=False,
                   right=False, labelbottom=False, labelleft=False)

plt.tight_layout()
plt.show()

I know that I can set up a bunch of subplots and tweak their arrangement by working out the geometry of it, but I think it could get a bit complicated, so I was hoping that there might be a simpler method available.

I should note that even though I'm using a (2, 3, 2) arrangement as my example, I'd like to do this for arbitrary collections, not just this one.

asked Mar 08 '23 at 10:03 by Paul


1 Answer

The idea is usually to find the common denominator between the subplots, i.e. the largest grid cell from which the desired layout can be built, and to let each subplot span several of those cells so that the desired layout is achieved.


Here you have 3 rows and 6 columns, and each subplot spans 1 row and 2 columns. The only difference between rows is that the subplots in the first and last rows span column positions 1/2 and 3/4, while those in the second row span positions 0/1, 2/3 and 4/5.
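Since the question asks for arbitrary arrangements, the same bookkeeping can be computed rather than written out by hand. A minimal sketch, assuming every subplot spans exactly two grid columns (`skewed_column_slices` is a hypothetical helper, not part of matplotlib): with `2 * max(arrangement)` columns, a row with `n` plots leaves `2 * (max - n)` columns unused, so shifting it right by `max - n` columns always centers it with an integer offset. This picks half a subplot width as the grid unit rather than, say, a least common multiple of the row counts; for (2, 3, 2) it yields the same 6 columns and slices as the code below.

```python
def skewed_column_slices(arrangement):
    """Return (ncols, slots) for centered rows of equally wide subplots.

    Hypothetical helper: each subplot spans two grid columns, so the grid
    needs 2 * max(arrangement) columns. Each slot is (row, start, stop),
    i.e. the subplot occupies gs[row, start:stop].
    """
    widest = max(arrangement)
    ncols = 2 * widest
    slots = []
    for row, n in enumerate(arrangement):
        offset = widest - n  # empty columns on each side of this row
        for k in range(n):
            start = offset + 2 * k
            slots.append((row, start, start + 2))
    return ncols, slots


ncols, slots = skewed_column_slices((2, 3, 2))
print("GridSpec({}, {})".format(3, ncols))
for row, start, stop in slots:
    print("gs[{}, {}:{}]".format(row, start, stop))
```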

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

gs = gridspec.GridSpec(3, 6)
ax1a = plt.subplot(gs[0, 1:3])
ax1b = plt.subplot(gs[0, 3:5])
ax2a = plt.subplot(gs[1, :2])
ax2b = plt.subplot(gs[1, 2:4])
ax2c = plt.subplot(gs[1, 4:])
ax3a = plt.subplot(gs[2, 1:3])
ax3b = plt.subplot(gs[2, 3:5])


# Label each axes and hide its ticks and tick labels
for i, ax in enumerate(plt.gcf().axes):
    ax.text(0.5, 0.5, "Axis: {}".format(i), fontweight='bold',
            va="center", ha="center")
    ax.tick_params(axis='both', bottom=False, top=False, left=False,
                   right=False, labelbottom=False, labelleft=False)

plt.tight_layout()

plt.show()
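The hard-coded slices above can be generalized to any arrangement. A minimal sketch, assuming each axes spans two columns of a `GridSpec` with `2 * max(arrangement)` columns and that shorter rows are shifted toward the center; `skewed_subplots` is a hypothetical helper, not a matplotlib function, and it uses only `GridSpec` slicing and `Figure.add_subplot`.

```python
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


def skewed_subplots(arrangement, fig=None):
    """Create centered rows of equally sized axes (hypothetical helper).

    arrangement[r] is the number of axes in row r. Every axes spans two
    columns of a GridSpec with 2 * max(arrangement) columns.
    """
    if fig is None:
        fig = plt.figure()
    widest = max(arrangement)
    gs = GridSpec(len(arrangement), 2 * widest, figure=fig)
    axes = []
    for row, n in enumerate(arrangement):
        offset = widest - n  # empty columns on each side of this row
        for k in range(n):
            start = offset + 2 * k
            axes.append(fig.add_subplot(gs[row, start:start + 2]))
    return fig, axes


fig, axes = skewed_subplots((2, 3, 2))
for i, ax in enumerate(axes):
    ax.text(0.5, 0.5, "Axis: {}".format(i), fontweight='bold',
            va="center", ha="center")
    ax.tick_params(axis='both', bottom=False, left=False,
                   labelbottom=False, labelleft=False)
fig.tight_layout()
```

Call `plt.show()` afterwards to display the figure. Because all columns of the grid have equal width, every axes gets the same width regardless of how many axes share its row.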


answered Mar 11 '23 at 06:03 by ImportanceOfBeingErnest