Error when looping to produce subplots

I have a question about an error I receive when looping to plot multiple subplots from a data frame.

My data frame has many columns, which I loop over to create one subplot per column.

This is my code:

    import matplotlib.pyplot as plt

    def plot(df):
        # collect the column names of the data frame
        channels = []
        for i in df:
            channels.append(i)

        # one row of subplots per column
        fig, ax = plt.subplots(len(channels), sharex=True, figsize=(50, 100))

        plot = 0
        for j in df:
            ax[plot].plot(df["%s" % j])
            ax[plot].set_xlabel('%s' % j)
            plot = plot + 1

        plt.tight_layout()
        plt.show()

The plot is produced fine, but I also get an empty frame and this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\AClayton\WinPython-64bit-2.7.5.3\python-2.7.5.amd64\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 538, in runfile
    execfile(filename, namespace)
  File "C:/Users/AClayton/Desktop/Data/TS.py", line 67, in <module>
    plot(all_data)
  File "C:/Users/AClayton/Desktop/Data/TS.py", line 49, in plot
    ax[plot].plot(reader["%s" % j])
TypeError: 'AxesSubplot' object does not support indexing

I can't see where this error comes from, given that the first plot is produced fine, or why a second figure is produced.

Thanks for any insight.

Ashleigh Clayton asked Nov 13 '13 12:11


1 Answer

When you create multiple subplots, plt.subplots() returns the axes in a NumPy array, and that array supports indexing like ax[plot]. When only one subplot is created, it returns the Axes object itself by default, not an array containing it.
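A minimal sketch to check this difference, assuming a recent matplotlib (where the class is called Axes rather than AxesSubplot, so the exact error message differs from the traceback above) and the non-interactive Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

# Several subplots: an ndarray holding one Axes per row
fig_many, ax_many = plt.subplots(3, sharex=True)
print(type(ax_many).__name__, ax_many.shape)

# A single subplot: the Axes object itself, not wrapped in an array
fig_one, ax_one = plt.subplots(1, sharex=True)
print(isinstance(ax_one, np.ndarray))

# Indexing the bare Axes reproduces the TypeError from the question
caught = None
try:
    ax_one[0]
except TypeError as exc:
    caught = exc
print("indexing a single Axes raised:", repr(caught))

plt.close("all")
```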

So your error occurs when len(channels) equals 1, that is, when the data frame has a single column. You can suppress this behavior by setting squeeze=False in the plt.subplots() call. This forces it to always return a 2D array of shape (rows, cols) containing the axes, even if there is only one.
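A short sketch of what squeeze=False returns, again assuming the Agg backend; the labels are placeholder text for illustration only:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt

# With squeeze=False the result is always a 2-D (rows x cols) array of Axes
_, ax_one = plt.subplots(1, 1, squeeze=False)
_, ax_three = plt.subplots(3, 1, sharex=True, squeeze=False)

print(ax_one.shape)    # (1, 1)
print(ax_three.shape)  # (3, 1)

# The same [row, 0] indexing now works for one subplot or many
ax_one[0, 0].set_xlabel("only column")
for row in range(ax_three.shape[0]):
    ax_three[row, 0].set_xlabel("column %d" % row)

plt.close("all")
```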

So:

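    import matplotlib.pyplot as plt

    def plot(df):
        # collect the column names of the data frame
        channels = []
        for i in df:
            channels.append(i)

        # squeeze=False: ax is always a 2-D array, even for a single column
        fig, ax = plt.subplots(len(channels), 1, sharex=True, figsize=(50, 100), squeeze=False)

        plot = 0
        for j in df:
            # row index first, then column 0 (there is only one column)
            ax[plot, 0].plot(df["%s" % j])
            ax[plot, 0].set_xlabel('%s' % j)
            plot = plot + 1

        plt.tight_layout()
        plt.show()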

With the squeeze keyword set to False you always get a 2D array back, so the indexing for a subplot changes to ax[plot,0]. I have also explicitly passed the number of columns (1 in this case).
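A variation on the same idea, not taken from the original answer: flatten the squeeze=False array with ravel() and pair each column with its Axes via zip, so no manual counter is needed. This is a sketch only; `plot_columns` is a hypothetical helper name, it uses `df[name]` instead of `df["%s" % j]` so that non-string column labels also work, it leaves out the large figsize, and it returns the figure rather than calling plt.show() so the caller decides when to display it. It assumes pandas and the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import pandas as pd


def plot_columns(df):
    """Draw one stacked subplot per DataFrame column (hypothetical helper)."""
    # squeeze=False guarantees a 2-D array; ravel() flattens it to one Axes per row
    fig, axes = plt.subplots(len(df.columns), 1, sharex=True, squeeze=False)
    axes = axes.ravel()
    # zip pairs each column label with its own Axes
    for name, ax in zip(df.columns, axes):
        ax.plot(df[name])
        ax.set_xlabel(str(name))
    fig.tight_layout()
    return fig, axes


# Works for a single column as well as for several
fig1, axes1 = plot_columns(pd.DataFrame({"a": [1, 2, 3]}))
fig3, axes3 = plot_columns(pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}))
print(len(axes1), len(axes3))
plt.close("all")
```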

Rutger Kassies answered Oct 04 '22 06:10