Prevent cutoff of imshow subplots of unequal dimensions with sharex

I have several matrices I want to display with imshow, in subplots of the same figure. They all have the same number of columns but different numbers of rows. I want to:

  1. see all of each matrix when displayed with imshow
  2. retain the aspect=1 effect of imshow
  3. use sharex in the figure

Together, these imply that the subplot heights must reflect the varying number of rows in the matrices. I tried using gridspec (via the gridspec_kw argument of plt.subplots), but the combination of sharex and aspect=1 cuts off parts of the matrices unless I manually resize the window. Example:

import numpy as np
import matplotlib.pyplot as plt

# fake data: same number of columns (7), different numbers of rows
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)

data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array(list('abcdefghijk'))
col_labels = list('ABCDEFG')

# initialize figure: one column of axes, heights proportional to row counts
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                        gridspec_kw=dict(height_ratios=nrows))

for ix, d in enumerate(data):
    # fill the grid of axes column by column
    ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
    ax.imshow(d)
    ax.yaxis.set_ticks(range(d.shape[0]))
    ax.xaxis.set_ticks(range(d.shape[1]))
    ax.yaxis.set_ticklabels(row_labels[:d.shape[0]])
    ax.xaxis.set_ticklabels(col_labels)

plt.show()

[image: resulting figure with default figure dimensions]
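The three requirements listed above pin down the shape of every axes box: with aspect=1, a matrix with r rows and c columns is drawn as an r-by-c block of square cells, so its box needs a height/width ratio of r/c, and the single-column gridspec gives every box the same width. A minimal sketch of that arithmetic for the fake data; target_box_heights and the 3.5 inch width are hypothetical choices for illustration, not matplotlib API:

```python
def target_box_heights(nrows, ncols, width_in):
    """Hypothetical helper: height (inches) each axes box needs so that
    aspect=1 image cells exactly fill a box width_in inches wide."""
    cell = width_in / ncols  # side length of one square image cell
    return [r * cell for r in nrows]


nrows = [5, 11, 3]  # row counts of foo, bar, baz
heights = target_box_heights(nrows, ncols=7, width_in=3.5)
print(heights)  # [2.5, 5.5, 1.5]
```

The heights come out proportional to the row counts, which is why height_ratios=nrows is the right ratio to ask for; whether the figure actually leaves the boxes that much room is a separate question that depends on figsize and the margins.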

Based on the number of rows in each matrix, I can guess at a ballpark figure dimension that should make them all visible, but it doesn't work:

# same data and labels as above; figure size guessed from the matrix shapes
figsize = (foo.shape[1], sum(nrows))  # (7, 19) inches
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                        gridspec_kw=dict(height_ratios=nrows),
                        figsize=figsize)

for ix, d in enumerate(data):
    ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
    ax.imshow(d)
    ax.yaxis.set_ticks(range(d.shape[0]))
    ax.xaxis.set_ticks(range(d.shape[1]))
    ax.yaxis.set_ticklabels(row_labels[:d.shape[0]])
    ax.xaxis.set_ticklabels(col_labels)

plt.show()

Notice that the top and bottom rows of all three subplots are partially cut off (easiest to see in the middle one), and yet there is a lot of excess whitespace in the top and bottom figure margins:

[image: resulting figure with custom figure dimensions]
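One plausible reason the guessed size misses: figsize sets the size of the whole figure, not of the axes area. Assuming matplotlib's default subplot parameters (rcParams figure.subplot.left=0.125, right=0.9, bottom=0.11, top=0.88, hspace=0.2; a custom matplotlibrc could change them), the margins and the gaps between subplots take space away from the boxes. The sketch below mirrors how a one-column GridSpec divides the remaining area (hspace is a fraction of the average axes height) and compares each box's height/width ratio with the rows/columns ratio an aspect=1 image needs; axes_boxes is a hypothetical helper, not a matplotlib function.

```python
# Matplotlib's default subplot parameters (rcParams["figure.subplot.*"]);
# assumption: the configuration in use does not override them.
LEFT, RIGHT, BOTTOM, TOP, HSPACE = 0.125, 0.9, 0.11, 0.88, 0.2


def axes_boxes(figsize, height_ratios):
    """Hypothetical helper: (width, height) in inches of each axes
    in a one-column GridSpec with the parameters above."""
    fig_w, fig_h = figsize
    n = len(height_ratios)
    area_w = fig_w * (RIGHT - LEFT)  # width left after the side margins
    area_h = fig_h * (TOP - BOTTOM)  # height left after top/bottom margins
    # n boxes plus (n - 1) gaps of HSPACE * mean box height fill area_h
    mean_h = area_h / (n + HSPACE * (n - 1))
    total = sum(height_ratios)
    return [(area_w, r * n * mean_h / total) for r in height_ratios]


nrows = [5, 11, 3]  # foo, bar, baz
ncols = 7
boxes = axes_boxes(figsize=(ncols, sum(nrows)), height_ratios=nrows)
for r, (w, h) in zip(nrows, boxes):
    have = h / w      # height/width ratio the gridspec box provides
    need = r / ncols  # ratio an aspect=1 image of r x ncols cells needs
    print(f"{r:2d} rows: box ratio {have:.3f}, image needs {need:.3f}")
```

For figsize=(7, 19) every box comes out at about 0.877 of the ratio its image needs, i.e. roughly 12% too short. Something has to give: either the box shrinks or the visible data range does, and the clipped top and bottom rows in the screenshot are consistent with the latter.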

Using tight_layout doesn't solve it either; it makes the subplots too big (note the gap between the axis spine and the image at the top and bottom of each subplot):

[image: resulting figure with custom dimensions and tight_layout]
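Instead of judging from screenshots, clipping can also be checked numerically. The sketch below uses a hypothetical helper, clipped_images (not a matplotlib function): it draws the figure with the off-screen Agg backend, because aspect=1 is applied at draw time, and then compares each axes' view limits with the extent of its image. What it reports for the question's layout may depend on the matplotlib version, since the handling of aspect on shared axes has changed over the years.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np


def clipped_images(fig, tol=1e-9):
    """Hypothetical helper: indices of axes whose first image is cut off.

    Draws the figure first, because aspect=1 is only applied at draw time.
    """
    fig.canvas.draw()
    clipped = []
    for i, ax in enumerate(fig.axes):
        if not ax.images:  # skip axes without an imshow image
            continue
        left, right, bottom, top = ax.images[0].get_extent()
        xlo, xhi = sorted(ax.get_xlim())
        ylo, yhi = sorted(ax.get_ylim())
        # the view must reach every edge of the image to show all of it
        visible = (xlo <= min(left, right) + tol
                   and xhi >= max(left, right) - tol
                   and ylo <= min(bottom, top) + tol
                   and yhi >= max(bottom, top) - tol)
        if not visible:
            clipped.append(i)
    return clipped


# usage: the gridspec + sharex layout from the question
row_counts = [5, 11, 3]
fig, axs = plt.subplots(3, 1, sharex=True,
                        gridspec_kw=dict(height_ratios=row_counts))
for ax, r in zip(axs, row_counts):
    ax.imshow(np.arange(r * 7).reshape(r, 7))
print(clipped_images(fig))  # an empty list means every matrix is fully shown
```

Calling it again after fig.tight_layout() or after changing figsize shows whether that change clipped anything.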

Is there any way to get imshow and sharex to work in harmony here?

asked Dec 06 '25 at 05:12 by drammock

2 Answers

I just discovered ImageGrid, which does the trick nicely. Complete example:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# same fake data and labels as in the question
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array(list('abcdefghijk'))
col_labels = list('ABCDEFG')

fig = plt.figure()
# 3 rows x 1 column of image axes, 0.1 inch apart (axes_pad is in inches)
axs = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.1)
for ix, d in enumerate(data):
    ax = axs[ix]
    ax.imshow(d)
    ax.yaxis.set_ticks(range(d.shape[0]))
    ax.xaxis.set_ticks(range(d.shape[1]))
    ax.yaxis.set_ticklabels(row_labels[:d.shape[0]])
    ax.xaxis.set_ticklabels(col_labels)

plt.show()

[image: result using ImageGrid]

answered Dec 07 '25 at 19:12 by drammock
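If you would rather keep plt.subplots and gridspec instead of ImageGrid, an alternative (a sketch, not part of the answer above) is to invert the layout arithmetic: pick an axes width, compute the box heights that aspect=1 images need, and size the figure so the margins and hspace leave exactly that much room. figsize_for_images is a hypothetical helper; its margin defaults assume matplotlib's default rcParams figure.subplot.* values.

```python
def figsize_for_images(nrows, ncols, axes_width,
                       left=0.125, right=0.9, bottom=0.11, top=0.88,
                       hspace=0.2):
    """Hypothetical helper: figure size (inches) for a one-column GridSpec
    whose boxes have exactly the proportions of aspect=1 images.

    Margin/hspace defaults mirror matplotlib's default
    rcParams["figure.subplot.*"]; pass your own values if you change them.
    """
    n = len(nrows)
    cell = axes_width / ncols            # side of one square image cell
    heights = [r * cell for r in nrows]  # box height each image needs
    mean_h = sum(heights) / n            # GridSpec measures hspace against this
    area_h = sum(heights) + hspace * mean_h * (n - 1)  # boxes plus gaps
    fig_w = axes_width / (right - left)  # add the side margins back
    fig_h = area_h / (top - bottom)      # add the top/bottom margins back
    return fig_w, fig_h


nrows = [5, 11, 3]  # foo, bar, baz
figsize = figsize_for_images(nrows, ncols=7, axes_width=3.5)
print(figsize)
```

The result would replace the guessed figsize in the question's second example, e.g. plt.subplots(3, 1, squeeze=False, sharex=True, gridspec_kw=dict(height_ratios=nrows), figsize=figsize). The match holds only for the margin values used in the calculation: if you change them with fig.subplots_adjust, pass the same values to the helper, and note that tight_layout recomputes the margins and would undo the match.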

