
Plotting Pandas groupby groups using subplots and loop

I am trying to generate a grid of subplots based on a Pandas groupby object. I would like each plot to be based on two columns of data for one group of the groupby object. Fake data set:

C1,C2,C3,C4
1,12,125,25
2,13,25,25
3,15,98,25
4,12,77,25
5,15,889,25
6,13,56,25
7,12,256,25
8,12,158,25
9,13,158,25
10,15,1366,25
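To experiment without the CSV file, here is a minimal sketch that loads the same fake data from an in-memory string and inspects the groups. Using io.StringIO instead of fake_data.csv is an assumption made only so the snippet is self-contained. Grouping on C2 yields three groups, keyed 12, 13 and 15.

```python
import io

import pandas as pd

# The fake data set above, held in a string instead of fake_data.csv (assumption)
CSV_TEXT = """C1,C2,C3,C4
1,12,125,25
2,13,25,25
3,15,98,25
4,12,77,25
5,15,889,25
6,13,56,25
7,12,256,25
8,12,158,25
9,13,158,25
10,15,1366,25
"""

df = pd.read_csv(io.StringIO(CSV_TEXT))
grouped = df.groupby('C2')

# One entry per distinct C2 value: group key -> number of rows in that group
group_sizes = {int(name): len(group) for name, group in grouped}
print(len(grouped), group_sizes)  # 3 {12: 4, 13: 3, 15: 3}
```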

I have tried the following code:

import pandas as pd
import csv   
import matplotlib as mpl
import matplotlib.pyplot as plt
import math

#Path to CSV File
path = "..\\fake_data.csv"

#Read CSV into pandas DataFrame
df = pd.read_csv(path)

#GroupBy C2
grouped = df.groupby('C2')

#Figure out number of rows needed for 2 column grid plot
#Also accounts for odd number of plots
nrows = int(math.ceil(len(grouped)/2.))

#Setup Subplots
fig, axs = plt.subplots(nrows,2)

for ax in axs.flatten():
    for i,j in grouped:
        j.plot(x='C1',y='C3', ax=ax)

plt.savefig("plot.png")

But it generates 4 identical subplots with all of the data plotted on each (see example output below):

[Image: example output showing four identical subplots]
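Why there are four subplots, all identical: with three groups, nrows = ceil(3 / 2.) = 2, so plt.subplots(2, 2) creates four axes, and the nested loop draws every group on every axis. The sketch below counts the lines on each axis to confirm this. It loads the fake data from an in-memory string and selects matplotlib's non-interactive Agg backend; both are assumptions so it runs without the CSV file or a display.

```python
import io
import math

import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
import pandas as pd

CSV_TEXT = """C1,C2,C3,C4
1,12,125,25
2,13,25,25
3,15,98,25
4,12,77,25
5,15,889,25
6,13,56,25
7,12,256,25
8,12,158,25
9,13,158,25
10,15,1366,25
"""

df = pd.read_csv(io.StringIO(CSV_TEXT))
grouped = df.groupby('C2')

nrows = int(math.ceil(len(grouped) / 2.))  # ceil(3 / 2) == 2
fig, axs = plt.subplots(nrows, 2)           # 2 x 2 grid -> 4 axes

# The original nested loop: every group is drawn on every axis
for ax in axs.flatten():
    for name, group in grouped:
        group.plot(x='C1', y='C3', ax=ax)

# Each plot() call adds one line, so every axis ends up with 3 lines
lines_per_axis = [len(ax.get_lines()) for ax in axs.flat]
print(nrows, lines_per_axis)  # 2 [3, 3, 3, 3]
```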

I would like to do something like the following to fix this:

for i,j in grouped:
    j.plot(x='C1',y='C3',ax=axs)
    next(axs)

but I get this error:

AttributeError: 'numpy.ndarray' object has no attribute 'get_figure'
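The error comes from the type of axs: plt.subplots(nrows, 2) returns a NumPy array of Axes objects, not a single Axes, so ax=axs hands plot an ndarray, which has no get_figure method. Calling next(axs) would also fail, because an array is iterable but is not itself an iterator. A short sketch showing both points (the Agg backend is an assumption so it runs headless):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(2, 2)

is_array = isinstance(axs, np.ndarray)
print(is_array, axs.shape)               # True (2, 2)
print(hasattr(axs, "get_figure"))        # False, so plot(ax=axs) fails
print(hasattr(axs[0, 0], "get_figure"))  # True: each element is an Axes

# An ndarray is iterable but is not an iterator, so next() rejects it
next_failed = False
try:
    next(axs)
except TypeError:
    next_failed = True
print(next_failed)  # True
```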

I will have a dynamic number of groups in the groupby object I want to plot, and many more elements than in the fake data I have provided. This is why I need an elegant, dynamic solution that plots each group's data on a separate subplot.

asked May 21 '15 16:05 by fireitup


1 Answer

Sounds like you want to iterate over the groups and the axes in parallel, so rather than having nested for loops (which iterate over all groups for each axis), you want something like this:

# Pair each group with its own subplot; "group" avoids shadowing the full df
for (name, group), ax in zip(grouped, axs.flat):
    group.plot(x='C1', y='C3', ax=ax)
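A fuller, self-contained sketch of the same zip approach. It assumes you also want each subplot titled with its group key and the leftover axis hidden (with three groups in a 2 x 2 grid, the fourth axis stays empty); the titles, squeeze=False, set_visible(False), the in-memory data and the Agg backend are all additions, not part of the answer above.

```python
import io
import math

import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
import pandas as pd

CSV_TEXT = """C1,C2,C3,C4
1,12,125,25
2,13,25,25
3,15,98,25
4,12,77,25
5,15,889,25
6,13,56,25
7,12,256,25
8,12,158,25
9,13,158,25
10,15,1366,25
"""

df = pd.read_csv(io.StringIO(CSV_TEXT))
grouped = df.groupby('C2')

ncols = 2
nrows = math.ceil(len(grouped) / ncols)
# squeeze=False keeps axs two-dimensional even when nrows == 1
fig, axs = plt.subplots(nrows, ncols, squeeze=False)

# zip stops at the shorter input, so each group gets exactly one axis
for (name, group), ax in zip(grouped, axs.flat):
    group.plot(x='C1', y='C3', ax=ax, title=f"C2 = {name}")

# Hide axes left over when the number of groups is not a multiple of ncols
for ax in axs.ravel()[len(grouped):]:
    ax.set_visible(False)

fig.tight_layout()
# fig.savefig("plot.png")  # save as in the question if desired
```

Because the ceil-based nrows always gives at least as many axes as groups, this works unchanged for any number of groups.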


You have the right idea in your second code snippet, but you're getting an error because axs is an array of axes, while plot expects a single Axes object. So it should also work to create an iterator before the loop with axs_iter = iter(axs.flat), then inside the loop call ax = next(axs_iter) before plotting and change the argument of plot to ax=ax.
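The next()-based fix from the paragraph above, as a minimal sketch: iter(axs.flat) turns the grid of axes into an iterator, and next() hands out one Axes per group in row-major order. The in-memory data and the Agg backend are assumptions to keep it self-contained.

```python
import io
import math

import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no display needed
import matplotlib.pyplot as plt
import pandas as pd

CSV_TEXT = """C1,C2,C3,C4
1,12,125,25
2,13,25,25
3,15,98,25
4,12,77,25
5,15,889,25
6,13,56,25
7,12,256,25
8,12,158,25
9,13,158,25
10,15,1366,25
"""

df = pd.read_csv(io.StringIO(CSV_TEXT))
grouped = df.groupby('C2')

nrows = int(math.ceil(len(grouped) / 2.))
fig, axs = plt.subplots(nrows, 2)

axs_iter = iter(axs.flat)  # iterator over the individual Axes objects
for name, group in grouped:
    ax = next(axs_iter)    # advance to the next unused subplot
    group.plot(x='C1', y='C3', ax=ax)

# Three groups fill the first three axes; the fourth stays empty
lines_per_axis = [len(ax.get_lines()) for ax in axs.flat]
print(lines_per_axis)  # [1, 1, 1, 0]
```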

answered Nov 02 '22 10:11 by mcwitt