 

Pandas side-by-side stacked bar plot

I want to create a stacked bar plot of the Titanic dataset. The plot needs to group by "Pclass", "Sex" and "Survived". I have managed to do this with a lot of tedious numpy manipulation to produce a normalized plot (where "M" is male and "F" is female).

Is there a way to do this using pandas' built-in plotting functionality?

I have tried this:

import pandas as pd
import matplotlib.pyplot as plt
df = pd.read_csv('train.csv')
df_grouped = df.groupby(['Survived','Sex','Pclass'])['Survived'].count()
df_grouped.unstack().plot(kind='bar', stacked=True, colormap='Blues', grid=True, figsize=(13, 5))


This is not what I want. Is there any way to produce the first plot using pandas plotting? Thanks in advance.

PyRsquared asked Nov 26 '17 09:11


1 Answer

The resulting bars will not neighbour each other as in your first figure, but outside of that, pandas lets you do what you want as follows:

# Fraction that survived and fraction that died per (Pclass, Sex) group
df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
df_g.plot.bar(stacked=True)
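
To see what the aggregation step produces without needing train.csv, here is a minimal sketch on a small hand-made frame. The data and the names `toy` and `toy_g` are hypothetical and only there for illustration; the point is that the mean of the 0/1 "Survived" column is the surviving fraction, so the two columns sum to 1 in every row.

```python
import pandas as pd

# Hypothetical toy data standing in for train.csv (values are made up).
toy = pd.DataFrame({
    "Pclass":   [1, 1, 1, 1, 2, 2, 3, 3, 3],
    "Sex":      ["female", "female", "male", "male", "female",
                 "male", "female", "male", "male"],
    "Survived": [1, 1, 1, 0, 1, 0, 0, 0, 1],
})

# The mean of a 0/1 column is the fraction of survivors per (Pclass, Sex) group.
survived = toy.groupby(["Pclass", "Sex"])["Survived"].mean()

# Two columns that sum to 1 in every row, i.e. the normalized table
# that a stacked bar plot would draw.
toy_g = pd.DataFrame({"Survived": survived, "Died": 1 - survived})
print(toy_g)
```

Calling `toy_g.plot.bar(stacked=True)` on this table should then give one full-height stacked bar per (Pclass, Sex) pair.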


Here, the horizontal grouping of patches is complicated by the requirement of stacking. If, for instance, we only cared about the value of "Survived", pandas could take care of it out-of-the-box.

df.groupby(['Pclass', 'Sex'])['Survived'].mean().unstack().plot.bar()


If an ad hoc solution suffices for post-processing the plot, doing so is also not terribly complicated:

import numpy as np
from matplotlib import ticker

df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
ax = df_g.plot.bar(stacked=True)

# Move back every second patch
for i in range(6):
    new_x = ax.patches[i].get_x() - (i%2)/2
    ax.patches[i].set_x(new_x)
    ax.patches[i+6].set_x(new_x)

# Update tick locations correspondingly
minor_tick_locs = [x.get_x()+1/4 for x in ax.patches[:6]]
major_tick_locs = np.array([x.get_x()+1/4 for x in ax.patches[:6]]).reshape(3, 2).mean(axis=1)
ax.set_xticks(minor_tick_locs, minor=True)
ax.set_xticks(major_tick_locs)

# Use indices from dataframe as tick labels
minor_tick_labels = df_g.index.get_level_values(1).values
major_tick_labels = df_g.index.levels[0].values
ax.xaxis.set_ticklabels(minor_tick_labels, minor=True)
ax.xaxis.set_ticklabels(major_tick_labels)

# Remove ticks and organize tick labels to avoid overlap
ax.tick_params(axis='x', which='both', bottom=False)
ax.tick_params(axis='x', which='minor', rotation=45)
ax.tick_params(axis='x', which='major', pad=35, rotation=0)
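
The patch and tick arithmetic above can be checked without drawing anything. This is a rough sketch that assumes the default bar width of 0.5 used by `DataFrame.plot.bar` (consistent with the `+1/4` offsets in the code above) and bars initially centred on x = 0, 1, ..., 5; the names `left` and `shifted` are only for illustration.

```python
import numpy as np

n_bars = 6    # 3 classes x 2 sexes
width = 0.5   # assumed: default bar width of DataFrame.plot.bar

# Left edges as first drawn: bar i is centred on x = i.
left = np.arange(n_bars) - width / 2

# Move every second bar back by 1/2 so it touches the bar before it.
shifted = left - (np.arange(n_bars) % 2) / 2

# Minor ticks sit at the bar centres, major ticks at the centre of each pair.
minor_tick_locs = shifted + width / 2
major_tick_locs = minor_tick_locs.reshape(3, 2).mean(axis=1)

print(minor_tick_locs)
print(major_tick_locs)
```

Under those assumptions the bars of each class end up adjacent, with a gap between classes, and each major tick falls on the boundary between the two bars of a pair.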


fuglede answered Nov 10 '22 00:11