
Embedding several inset axes in another axis using matplotlib

Is it possible to embed a changing number of plots in a matplotlib axis? For example, the inset_axes method is used to place inset axes inside parent axes:

[Image: example of an inset axes placed inside a parent axes]
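For reference, here is a minimal sketch of the built-in method. It assumes matplotlib 3.0 or newer, where `Axes.inset_axes` is available; the bounds and the random data are arbitrary choices for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(np.random.random(40))

# Bounds are [x0, y0, width, height] in the parent's axes coordinates (0 to 1),
# so this inset occupies roughly the upper-right quarter of the parent axes.
inset = ax.inset_axes([0.55, 0.55, 0.4, 0.4])
inset.plot(np.random.random(40), color="tab:orange")
```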

However, I have several rows of plots and I want to include some inset axes inside the last axis object of each row.

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 4, figsize=(15, 15))
for i in range(2):
    # First three columns (indices 0-2): one plot each
    for j in range(3):
        ax[i][j].plot(np.random.random(40))

    # number of inset axes
    number_inset = 5
    # Last column (index 3): currently all lines share a single axes
    for j in range(number_inset):
        ax[i][3].plot(np.random.random(40))

[Image: 2 x 4 grid of plots; the last column of each row shows five overlapping lines in a single axes]

Instead of drawing all 5 lines in the single axes of the last column, I want several inset axes, each containing one plot. Something like this:

[Image: the same grid, with the last column of each row split into several stacked inset axes]

The reason is that each row refers to a different item, and the last column is supposed to show that item's components. Is there a way to do this in matplotlib, or an alternative way to visualize it?

Thanks
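For comparison, the layout in the question can also be approximated without GridSpec by keeping the 2 x 4 grid and splitting the last axes of each row into stacked insets with `Axes.inset_axes`. This is only a sketch, not the approach used in the answer below: `stack_insets` and its `pad` parameter are hypothetical names, and it assumes matplotlib 3.0 or newer:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def stack_insets(parent, n, pad=0.05):
    """Hypothetical helper: split *parent* into *n* vertically stacked insets.

    The parent's own ticks and frame are hidden so only the insets show.
    Bounds are in the parent's axes coordinates (0 to 1); *pad* is the
    vertical gap between neighbouring insets in those same units.
    """
    parent.set_axis_off()
    height = (1 - pad * (n - 1)) / n
    insets = []
    for k in range(n):
        # Fill from the top down: k = 0 is the highest inset, k = n - 1 sits at y0 = 0.
        y0 = 1 - (k + 1) * height - k * pad
        insets.append(parent.inset_axes([0, y0, 1, height]))
    return insets

fig, ax = plt.subplots(2, 4, figsize=(15, 8))
number_inset = 5
insets_per_row = []
for i in range(2):
    for j in range(3):
        ax[i][j].plot(np.random.random(40))
    row_insets = stack_insets(ax[i][3], number_inset)
    for inset in row_insets:
        inset.plot(np.random.random(40))  # one component per inset
    insets_per_row.append(row_insets)
```

Because the insets are children of the last axes, they move with it, but their spacing is expressed in that axes' coordinates rather than in figure layout terms.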

Asked Jan 24 '26 10:01 by Robert Smith


1 Answer

As @hitzg mentioned, the most common way to accomplish something like this is to use GridSpec. A GridSpec defines a grid of cells that you can slice to produce subplots, and a single slice may span several rows or columns. It's an easy way to align fairly complex layouts that follow a regular grid.
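To make the slicing concrete, here is a small sketch (assuming a reasonably recent matplotlib) in which one slice spans several cells and the others are single cells:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(6, 4))
gs = GridSpec(3, 3, figure=fig)

# Indexing a GridSpec yields a SubplotSpec; a slice covering several cells
# produces one axes that spans all of them.
big = fig.add_subplot(gs[:, :2])  # all 3 rows of the first two columns
top = fig.add_subplot(gs[0, 2])   # single cell, top right
mid = fig.add_subplot(gs[1, 2])   # single cell, middle right
bot = fig.add_subplot(gs[2, 2])   # single cell, bottom right
```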

However, it may not be immediately obvious how to use it in this case. You'll need to create a GridSpec with numrows * numinsets rows and numcols columns, then create each "main" axes by slicing out a block of numinsets consecutive grid rows.

In the example below (2 rows, 4 columns, 3 insets), the GridSpec has 6 rows. We'd slice gs[:3, 0] to get the upper-left "main" axes, gs[3:, 0] to get the lower-left one, gs[:3, 1] to get the next upper axes, and so on. Each inset is a single grid row of the last column: gs[i, -1] for i from 0 to 5.
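The index arithmetic can be checked on its own, without creating any figure. `grid_slices` is a hypothetical helper that just returns the (row_start, row_stop, column) span each axes would occupy:

```python
def grid_slices(numrows, numcols, numinsets):
    """Hypothetical helper mirroring the slicing described above.

    A "main" axes in row i covers numinsets consecutive GridSpec rows,
    while each inset in that row covers exactly one GridSpec row of the
    last column. Spans are (row_start, row_stop, column), stop exclusive.
    """
    main, insets = {}, {}
    for i in range(numrows):
        for j in range(numcols - 1):
            main[i, j] = (i * numinsets, (i + 1) * numinsets, j)
        for k in range(numinsets):
            m = i * numinsets + k
            insets[i, k] = (m, m + 1, numcols - 1)
    return main, insets

main, insets = grid_slices(2, 4, 3)
print(main[0, 0], main[1, 0], main[0, 1])  # (0, 3, 0) (3, 6, 0) (0, 3, 1)
print([insets[1, k] for k in range(3)])   # [(3, 4, 3), (4, 5, 3), (5, 6, 3)]
```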

As a complete example:

import numpy as np
import matplotlib.pyplot as plt

def build_axes_with_insets(numrows, numcols, numinsets, **kwargs):
    """
    Makes a *numrows* x *numcols* grid of subplots with *numinsets* subplots
    embedded as "sub-rows" in the last column of each row.

    Returns a figure object and a *numrows* x *numcols* object ndarray where
    all but the last column consist of axes objects, and each entry of the
    last column is a *numinsets*-length object ndarray of axes objects.
    """
    fig = plt.figure(**kwargs)
    gs = plt.GridSpec(numrows*numinsets, numcols)

    axes = np.empty([numrows, numcols], dtype=object)
    for i in range(numrows):
        # Add "main" axes...
        for j in range(numcols - 1):
            axes[i, j] = fig.add_subplot(gs[i*numinsets:(i+1)*numinsets, j])

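        # Add inset axes: the last column holds an array of numinsets axes,
        # one per GridSpec row in this block. np.empty(dtype=object) fills
        # with None, so the inner array must be created before indexing it.
        axes[i, -1] = np.empty(numinsets, dtype=object)
        for k in range(numinsets):
            m = k + i * numinsets
            axes[i, -1][k] = fig.add_subplot(gs[m, -1])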

    return fig, axes

def plot(axes):
    """Recursive plotting function just to put something on each axes."""
    for ax in axes.flat:
        data = np.random.normal(0, 1, 100).cumsum()
        try:
            ax.plot(data)
            ax.set(xticklabels=[], yticklabels=[])
        except AttributeError:
            plot(ax)

fig, axes = build_axes_with_insets(2, 4, 3, figsize=(12, 6))
plot(axes)
fig.tight_layout()
plt.show()

[Image: resulting 2 x 4 layout, with the last column of each row split into three stacked axes]
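If you want the gap between the insets to differ from the gap between the main axes, one option is nested gridspecs via `SubplotSpec.subgridspec` (available since matplotlib 3.0). This is a hedged variant rather than the approach above: `build_axes_with_subgrid` and `inset_hspace` are hypothetical names, and here each main axes fills one outer cell instead of a block of grid rows:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def build_axes_with_subgrid(numrows, numcols, numinsets, inset_hspace=0.05, **kwargs):
    """Hypothetical variant of build_axes_with_insets using nested gridspecs.

    Each main axes fills one cell of an outer numrows x numcols grid. The cell
    in the last column is split into its own numinsets x 1 sub-grid, so the
    spacing between insets (inset_hspace) is independent of the outer spacing.
    Returns the same structure: an object ndarray whose last column holds
    numinsets-length arrays of axes.
    """
    fig = plt.figure(**kwargs)
    outer = fig.add_gridspec(numrows, numcols)
    axes = np.empty((numrows, numcols), dtype=object)
    for i in range(numrows):
        for j in range(numcols - 1):
            axes[i, j] = fig.add_subplot(outer[i, j])
        inner = outer[i, -1].subgridspec(numinsets, 1, hspace=inset_hspace)
        insets = np.empty(numinsets, dtype=object)
        for k in range(numinsets):
            insets[k] = fig.add_subplot(inner[k, 0])
        axes[i, -1] = insets
    return fig, axes

fig, axes = build_axes_with_subgrid(2, 4, 3, figsize=(12, 6))
for ax in axes[:, :-1].flat:        # main axes
    ax.plot(np.random.normal(0, 1, 100).cumsum())
for row_insets in axes[:, -1]:      # one array of insets per row
    for ax in row_insets:
        ax.plot(np.random.normal(0, 1, 100).cumsum())
```

The returned structure matches the function above, so the recursive `plot` helper would work on it unchanged.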

Answered Jan 25 '26 23:01 by Joe Kington


