 

Plot several barplots using matplotlib and subplot

I want to plot several bar plots using the matplotlib library (or another library, if possible) and place each one in its own cell of a subplot grid.

I am also using groupby to group the data by each categorical column and sum the values. I only want to show three columns (Num1, Num2, Num3):

import matplotlib.pyplot as plt

# Build a subplot grid with three rows and two columns
fig, axes = plt.subplots(figsize=(12, 8), nrows=3, ncols=2)
fig.tight_layout()

# Five categorical columns and three numerical columns of interest
# (data is a pandas DataFrame)
for i, category in enumerate(['Cat1', 'Cat2', 'Cat3', 'Cat4', 'Cat5']):
    ax = fig.add_subplot(3, 2, i+1)
    data.groupby(category).sum()[['Num1','Num2','Num3']].plot.bar(rot=0)
    plt.xticks(rotation=90)

What I get is six empty plots arranged in 3 rows and 2 columns, followed by the 5 correct plots stacked one after another in a single column. An example of one of the plots is shown in the image.

Thanks for your help and suggestions.

[Figure: example of one of the resulting bar plots]

asked Feb 28 '26 at 00:02 by mah65
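The symptom in the question can be reproduced in a few lines. This is a minimal sketch, assuming pandas and matplotlib are installed and using a small made-up DataFrame (the column names and values are hypothetical). When `DataFrame.plot.bar()` is called without an `ax` argument, pandas opens a new figure for the chart, so the 3x2 grid from `plt.subplots` stays empty and each bar chart lands in a figure of its own:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical data: one categorical column and two numerical columns
data = pd.DataFrame({
    "Cat1": ["a", "b", "a", "b"],
    "Num1": [1, 2, 3, 4],
    "Num2": [5, 6, 7, 8],
})

fig, axes = plt.subplots(nrows=3, ncols=2)  # figure 1: a 3x2 grid of empty axes
print(len(plt.get_fignums()))  # 1

# No ax= argument: pandas draws into a brand-new figure instead of the grid
data.groupby("Cat1")[["Num1", "Num2"]].sum().plot.bar(rot=0)
print(len(plt.get_fignums()))  # 2

# The grid's axes are still empty
print(all(not ax.has_data() for ax in axes.flat))  # True
```

Repeating the `plot.bar()` call five times in a loop gives five extra figures, which matches the "six empty plots followed by five correct plots" the question describes.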


2 Answers

When you create a figure with fig, axes = plt.subplots(figsize=(12, 8), nrows=3, ncols=2), all six subplots are already created by the nrows and ncols keywords, so there is no need to call fig.add_subplot again. axes is a 2-D NumPy array of Axes objects with shape (3, 2); axes.flat lets you step through it with the single loop counter i.
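A short sketch of what `plt.subplots` returns for a 3x2 grid, assuming the default `squeeze=True`. It shows why a single integer index such as `axes[0]` gives a whole row rather than one Axes, and how `axes.flat` walks the grid in row-major order:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=3, ncols=2)

print(type(axes).__name__, axes.shape)  # ndarray (3, 2)

# Indexing with a single integer returns a whole row (an array of 2 Axes),
# not a single Axes object
print(axes[0].shape)  # (2,)

# axes.flat walks the grid in row-major order: (0,0), (0,1), (1,0), ...
flat_axes = list(axes.flat)
print(len(flat_axes))  # 6
print(flat_axes[3] is axes[1, 1])  # True
```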

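I think everything should work fine if you change:

ax = fig.add_subplot(3,2,i+1)

to:

ax = axes.flat[i]

and pass ax=ax to .plot.bar() so that pandas draws into that subplot instead of opening a new figure.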

All together:

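import matplotlib.pyplot as plt

fig, axes = plt.subplots(figsize=(12, 8), nrows=3, ncols=2)

# Five categorical columns and three numerical columns of interest
for i, category in enumerate(['Cat1', 'Cat2', 'Cat3', 'Cat4', 'Cat5']):
    ax = axes.flat[i]  # axes is a 3x2 array; .flat indexes it in row-major order
    # Select the numerical columns before summing, and draw into this subplot
    data.groupby(category)[['Num1','Num2','Num3']].sum().plot.bar(rot=0, ax=ax)
    ax.tick_params(axis='x', labelrotation=90)  # Axes has no xticks() method

fig.tight_layout()  # call after plotting so the layout accounts for the labels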
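Below is a self-contained sketch of the same approach that can be run as-is. The DataFrame is hypothetical random data standing in for the asker's `data` (five categorical columns `Cat1`..`Cat5`, three numerical columns `Num1`..`Num3`). It pairs each category with a grid cell via `zip(axes.flat, ...)`, sets the label rotation directly through pandas' `rot` argument, and hides the sixth, unused cell:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical data standing in for the question's DataFrame:
# five categorical columns (Cat1..Cat5) and three numerical ones (Num1..Num3)
rng = np.random.default_rng(0)
n = 40
data = pd.DataFrame({f"Cat{k}": rng.choice(["x", "y", "z"], size=n) for k in range(1, 6)})
for col in ["Num1", "Num2", "Num3"]:
    data[col] = rng.integers(0, 10, size=n)

categories = ["Cat1", "Cat2", "Cat3", "Cat4", "Cat5"]
fig, axes = plt.subplots(figsize=(12, 8), nrows=3, ncols=2)

# zip stops after the five categories, leaving the sixth Axes unused
for ax, category in zip(axes.flat, categories):
    # Select the numeric columns before summing so the other Cat columns are ignored
    sums = data.groupby(category)[["Num1", "Num2", "Num3"]].sum()
    sums.plot.bar(ax=ax, rot=90)  # draw into this grid cell instead of a new figure
    ax.set_title(category)

# Hide the sixth, unused Axes in the 3x2 grid
axes[-1, -1].set_visible(False)
fig.tight_layout()
```

Only one figure is created here, and each of the first five cells holds three groups of bars (one per numerical column).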
answered Mar 01 '26 at 12:03 by dubbbdan


Thanks for your help, everyone.

The final code that worked:

import matplotlib.pyplot as plt

# Build a subplot grid with three rows and two columns
nrows = 3
ncols = 2
fig, axes = plt.subplots(figsize=(12, 16), nrows=nrows, ncols=ncols)

# Five categorical columns and three numerical columns of interest
for i, category in enumerate(['Cat1', 'Cat2', 'Cat3', 'Cat4', 'Cat5']):
    # Row-major position: i // ncols is the row, i % ncols is the column
    ax = axes[i // ncols][i % ncols]
    data.groupby(category)[['Num1','Num2','Num3']].sum().plot.bar(rot=0, ax=ax)

# Rotate the x tick labels on every subplot
for ax in fig.axes:
    plt.sca(ax)
    plt.xticks(rotation=90)

fig.tight_layout()
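Two ways to turn the loop counter `i` into a grid position are the row-major form `(i // ncols, i % ncols)` and the form `(i % nrows, i % ncols)`. The sketch below compares them; the helper names `row_major` and `modulo_only` are hypothetical. The second form happens to give every plot its own cell in a full 3x2 grid because 3 and 2 share no common factor (the Chinese remainder theorem), but it fills the cells in a scrambled order and can put two plots in the same cell when the dimensions share a factor:

```python
def row_major(i, ncols):
    """Row-major grid position of the i-th plot: fill each row left to right."""
    return (i // ncols, i % ncols)

def modulo_only(i, nrows, ncols):
    """Position from two independent remainders; collision-free for a full grid only if nrows and ncols are coprime."""
    return (i % nrows, i % ncols)

nrows, ncols = 3, 2
print([row_major(i, ncols) for i in range(5)])
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0)]
print([modulo_only(i, nrows, ncols) for i in range(5)])
# [(0, 0), (1, 1), (2, 0), (0, 1), (1, 0)]

# With a 2x2 grid the remainder-only mapping sends two plots to the same cell
print([modulo_only(i, 2, 2) for i in range(4)])
# [(0, 0), (1, 1), (0, 0), (1, 1)]
```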
answered Mar 01 '26 at 13:03 by mah65