 

Stuffing a pandas DataFrame.plot into a matplotlib subplot

My brain hurts

I have some code that produces 33 graphics in one long column

#fig,axes = plt.subplots(nrows=11,ncols=3,figsize=(18,50))
accountList =  list(set(training.account))
for i in range(1,len(accountList)):
    training[training.account == accountList[i]].plot(kind='scatter',x='date_int',y='rate',title=accountList[i])
#axes[0].set_ylabel('Success Rate')

I'd like to get each of those plots into the figure that I have commented out above, but all my attempts are failing. I tried putting ax=i into the plot command and I get 'numpy.ndarray' object has no attribute 'get_figure'. Also, when I scale back and do this with one single plot in a one by one figure, my x and y scales both go to heck. I feel like I'm close to the answer, but I need a little push. Thanks.

Kevin Thompson asked Feb 23 '14 00:02





1 Answer

The axes handles that subplots returns vary according to the number of subplots requested:

  • for (1x1) you get a single handle,
  • for (n x 1 or 1 x n) you get a 1d array of handles,
  • for (m x n) you get a 2d array of handles.
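A quick way to see the three cases is to inspect what subplots hands back. This is only a minimal sketch: the variable names are made up for illustration, and the Agg backend is selected so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig1, ax_single = plt.subplots(nrows=1, ncols=1)  # 1x1 case
fig2, ax_row = plt.subplots(nrows=1, ncols=3)     # 1 x n case
fig3, ax_grid = plt.subplots(nrows=11, ncols=3)   # m x n case

single_is_array = isinstance(ax_single, np.ndarray)  # a bare Axes, not an array
row_shape = ax_row.shape    # 1d array of Axes
grid_shape = ax_grid.shape  # 2d array of Axes

print(single_is_array, row_shape, grid_shape)
plt.close("all")
```

Passing one of these arrays (rather than one element of it) as ax= to DataFrame.plot is what produces the "'numpy.ndarray' object has no attribute 'get_figure'" error from the question.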

It appears that your problem arises from the change in interface from the 2nd to 3rd cases (i.e. 1d to 2d axis array). The following snippets can help if you don't know ahead of time what the array shape will be.

I have found numpy's unravel_index useful for iterating over the axes, e.g.:

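import numpy as np
import matplotlib.pyplot as plt

ncol = 3  # pick one dimension
nrow = (len(accountList) + ncol - 1) // ncol  # ceiling division: enough subplots for every account
fig, ax = plt.subplots(nrows=nrow, ncols=ncol)  # create the axes

for i in range(len(accountList)):  # go over a linear list of data
    ix = np.unravel_index(i, ax.shape)  # compute an appropriate index (1d or 2d)

    # pandas plot method, drawn onto the chosen axes (arguments taken from the question)
    training[training.account == accountList[i]].plot(kind='scatter', x='date_int', y='rate', title=accountList[i], ax=ax[ix])
    # or call an axes method directly: ax[ix].plot(...)  (/scatter/bar/...)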

You can also reshape the returned array so that it is linear (as I used in this answer):

for a in ax.reshape(-1):
    a.plot(...)

As noted in the linked solution, ax needs a bit of massaging if you might have 1x1 subplots, because you then receive a single axes handle rather than an array; wrapping it with ax = np.array(ax) before reshaping is enough.
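The wrap-then-reshape idea can be bundled into a small helper. This is a sketch only: flat_axes is a hypothetical name, not part of matplotlib or numpy, and the Agg backend is chosen just so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

def flat_axes(ax):
    """Hypothetical helper: return a 1d array of Axes for any subplots result."""
    # np.array wraps a bare Axes in a 0-d array; reshape(-1) then makes it length 1.
    return np.array(ax).reshape(-1)

_, ax1 = plt.subplots(1, 1)  # single handle
_, ax2 = plt.subplots(1, 3)  # 1d array
_, ax3 = plt.subplots(2, 3)  # 2d array

lengths = [len(flat_axes(a)) for a in (ax1, ax2, ax3)]
print(lengths)
plt.close("all")
```

With this, the same "for a in flat_axes(ax):" loop works whatever grid shape was requested.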


After reading the docs more carefully (oops): setting squeeze=False forces subplots to return a 2d array regardless of the choice of ncols/nrows (squeeze defaults to True).

If you do this, you can either iterate over both dimensions (if that is natural for your data), or use either of the above approaches to iterate over your data linearly, computing a 2d index into ax.
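Putting the pieces together, here is one possible end-to-end sketch using squeeze=False. The training frame is a synthetic stand-in with made-up values (the real data is not shown in the question); only the column names account, date_int and rate come from the original code. Hiding the unused grid cells is an extra step, not something the question asked for.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for the asker's `training` frame (values are invented).
rng = np.random.default_rng(0)
training = pd.DataFrame({
    "account": np.repeat(["a", "b", "c", "d"], 5),
    "date_int": np.tile(np.arange(5), 4),
    "rate": rng.random(20),
})

accountList = sorted(set(training.account))
ncol = 3
nrow = (len(accountList) + ncol - 1) // ncol  # ceiling division
fig, ax = plt.subplots(nrows=nrow, ncols=ncol, squeeze=False)  # always a 2d array

for i, account in enumerate(accountList):
    ix = np.unravel_index(i, ax.shape)  # linear index -> (row, col)
    training[training.account == account].plot(
        kind="scatter", x="date_int", y="rate", title=account, ax=ax[ix])

# Hide the grid cells that received no data.
for a in ax.reshape(-1)[len(accountList):]:
    a.set_visible(False)

n_drawn = sum(a.has_data() for a in ax.reshape(-1))
print(ax.shape, n_drawn)
```

Because squeeze=False guarantees a 2d array, the same loop also works when there is only one account or only one row.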

Bonlenfum answered Oct 15 '22 23:10
