
Plot bar chart in multiple subplot rows with Pandas

I have a simple long-form dataset I would like to generate bar charts from. The dataframe looks like this:

import pandas as pd

data = {'Year':[2019,2019,2019,2020,2020,2020,2021,2021,2021],
        'Month_diff':[0,1,2,0,1,2,0,1,2],
        'data': [12,10,13,16,12,18,19,45,34]}
df = pd.DataFrame(data)

I would like to plot bar charts in 3 rows, one each for 2019, 2020 and 2021, with Month_diff on the x-axis and data on the y-axis. How do I do this?

If the data were in different columns, I could have just used this code:

df.plot(x="X", y=["A", "B", "C"], kind="bar")

But my data is in a single column, and ideally I'd like a separate row for each year.

asked Nov 15 '21 at 23:11 by Patthebug


1 Answer

1. seaborn.catplot

The simplest option for a long-form dataframe is the seaborn.catplot wrapper, as Johan said:

import seaborn as sns
sns.catplot(data=df, x='Month_diff', y='data', row='Year',
            kind='bar', height=2, aspect=4)


2. pivot + DataFrame.plot

Without seaborn:

  • pivot from long-form to wide-form (1 year per column)
  • use DataFrame.plot with subplots=True to put each year into its own subplot (and optionally sharey=True)
import matplotlib.pyplot as plt

# one column per year, then one subplot per column
(df.pivot(index='Month_diff', columns='Year', values='data')
   .plot.bar(subplots=True, sharey=True, legend=False))
plt.tight_layout()
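
To see what the pivot step produces before anything is plotted, here is a minimal sketch that rebuilds the question's dataframe and prints the wide-form result. The variable name `wide` is only for illustration; the result should have one row per Month_diff and one column per Year (3 x 3 for this data).

```python
import pandas as pd

# Same long-form data as in the question
df = pd.DataFrame({
    'Year': [2019, 2019, 2019, 2020, 2020, 2020, 2021, 2021, 2021],
    'Month_diff': [0, 1, 2, 0, 1, 2, 0, 1, 2],
    'data': [12, 10, 13, 16, 12, 18, 19, 45, 34],
})

# Long-form -> wide-form: Month_diff becomes the index, each Year a column
wide = df.pivot(index='Month_diff', columns='Year', values='data')
print(wide)
```

Each column of `wide` is what DataFrame.plot draws as its own subplot when subplots=True.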

Note that if you prefer a single grouped bar chart (which you alluded to at the end), you can just leave out the subplots param:

df.pivot(index='Month_diff', columns='Year', values='data').plot.bar()


3. DataFrame.groupby + subplots

You can also iterate the df.groupby('Year') object:

  • Create a subplots grid of axes based on the number of groups (years)
  • Plot each group (year) onto its own subplot row
groups = df.groupby('Year')
fig, axs = plt.subplots(nrows=len(groups), ncols=1, sharex=True, sharey=True)

for (name, group), ax in zip(groups, axs):
    group.plot.bar(x='Month_diff', y='data', legend=False, ax=ax)
    ax.set_title(name)

# Figure.supylabel requires matplotlib 3.4 or newer
fig.supylabel('data')
plt.tight_layout()
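
One caveat with the loop above: when the dataframe contains only a single year, plt.subplots returns a single Axes rather than an array, so the zip would fail. A possible workaround, sketched here under the assumption that you want the same code to handle any number of years, is to pass squeeze=False. The helper name `plot_bars_per_year` is hypothetical and not part of pandas or matplotlib.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_bars_per_year(df):
    """Draw one bar-chart row per year (hypothetical helper for illustration)."""
    groups = df.groupby('Year')
    # squeeze=False always returns a 2D array of axes, even for one group
    fig, axs = plt.subplots(nrows=groups.ngroups, ncols=1,
                            sharex=True, sharey=True, squeeze=False)
    # axs has shape (ngroups, 1), so take the single column
    for (name, group), ax in zip(groups, axs[:, 0]):
        group.plot.bar(x='Month_diff', y='data', legend=False, ax=ax)
        ax.set_title(str(name))
    fig.tight_layout()
    return fig, axs[:, 0]
```

Calling plot_bars_per_year(df) with the question's dataframe should give three stacked rows, and a dataframe filtered to one year should give a single row without any code changes.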

answered Oct 26 '22 at 23:10 by tdy