I have several matrices I want to display with `imshow`, in subplots of the same figure. They all have the same number of columns but a varying number of rows. I want to:

- display each matrix with `imshow`
- keep the `aspect=1` effect of `imshow`
- use `sharex` in the figure

Together, these imply that the heights of the subplots should reflect the varying number of rows in the matrices. I tried using `gridspec` (via the `gridspec_kw` argument of `plt.subplots`), but the combination of `sharex` and `aspect=1` leads to parts of the matrices getting cut off unless I manually resize the window. Example:
```python
import numpy as np
import matplotlib.pyplot as plt

# fake data
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']

# initialize figure
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                        gridspec_kw=dict(height_ratios=nrows))
for ix, d in enumerate(data):
    ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)
```

Based on the number of rows in each matrix, I can guess at a ballpark figure dimension that should make them all visible, but it doesn't work:
```python
figsize = (foo.shape[1], sum(nrows))
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                        gridspec_kw=dict(height_ratios=nrows),
                        figsize=figsize)
for ix, d in enumerate(data):
    ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)
```
Notice how the top and bottom rows of all three subplots are partially cut off (it's easiest to see on the middle one), and yet there is a lot of excess whitespace in the top and bottom margins of the figure.

Using `tight_layout` doesn't solve it either; it makes the subplots too big (note the gap at the top and bottom of each subplot between the axis spine and the image).

Is there any way to get imshow and sharex to work in harmony here?
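For reference, the ballpark `figsize = (foo.shape[1], sum(nrows))` above ignores the figure margins and the `hspace` gap that `gridspec` reserves between subplots. A minimal sketch, assuming a single column of subplots and matplotlib's default subplot parameters (`left=0.125`, `right=0.9`, `bottom=0.11`, `top=0.88`, `hspace=0.2`), can compute the figure height at which every gridspec cell already has a 1:1 data aspect. The helpers `figure_height_for_square_cells` and `row_height_over_col_width` are hypothetical names, and the arithmetic is meant to mirror how `GridSpec.get_grid_positions` splits the figure among `height_ratios`.

```python
# Hypothetical helpers (not part of matplotlib). They assume one column of
# subplots made with plt.subplots(..., gridspec_kw=dict(height_ratios=nrows))
# and matplotlib's default subplot parameters unless others are passed in.


def figure_height_for_square_cells(nrows, ncols, fig_width,
                                   left=0.125, right=0.9,
                                   bottom=0.11, top=0.88, hspace=0.2):
    """Figure height (inches) at which every imshow data cell is square."""
    n = len(nrows)
    # Width of one data column in inches; with a single column of subplots
    # each axes spans the full width between the left and right margins.
    col_width = fig_width * (right - left) / ncols
    # Average subplot height as a fraction of the figure, hspace gaps
    # excluded; the height ratios then split n times that among subplots.
    avg_cell_frac = (top - bottom) / (n + hspace * (n - 1))
    frac_per_row = avg_cell_frac * n / sum(nrows)
    # Solve fig_height * frac_per_row == col_width for fig_height.
    return col_width / frac_per_row


def row_height_over_col_width(nrows, ncols, figsize, **subplot_params):
    """Ratio of the height available per data row to the width per column.

    1.0 means the gridspec cells already match aspect=1. The row height
    scales linearly with the figure height, hence the simple division.
    """
    fig_width, fig_height = figsize
    needed = figure_height_for_square_cells(nrows, ncols, fig_width,
                                            **subplot_params)
    return fig_height / needed


nrows = [5, 11, 3]  # rows of foo, bar, baz
height = figure_height_for_square_cells(nrows, ncols=7, fig_width=7)
print(round(height, 2))                                        # 21.67
print(round(row_height_over_col_width(nrows, 7, (7, 19)), 3))  # 0.877

# The height could then be passed to the question's call, e.g.
# plt.subplots(3, 1, squeeze=False, sharex=True,
#              gridspec_kw=dict(height_ratios=nrows), figsize=(7, height))
```

With the question's `figsize=(7, 19)` the ratio comes out below 1, meaning the gridspec cells are shorter than square data cells would need. This sketch only sizes the boxes: it ignores tick labels and `tight_layout`, and it does not guarantee how matplotlib's aspect handling for shared axes will then treat the limits.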
I just discovered ImageGrid, which does the trick nicely. Complete example:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']

fig = plt.figure()
# In a single-column ImageGrid the axes share their x-axis, and
# aspect=True (the default) keeps each image's aspect ratio.
axs = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.1)
for ix, d in enumerate(data):
    ax = axs[ix]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)
```
