Stacked bar plot by grouped data with pandas

Suppose I have a pandas DataFrame with many features, of which I am interested in two. I'll call them feature1 and feature2.

feature1 can take three possible values. feature2 can take two possible values.

I need a bar plot grouped by feature1 and stacked by the count of rows with each value of feature2 (so there will be three bars, one per feature1 value, each made of two stacked segments).

How can I achieve this?

At the moment I have

import pandas as pd
df = pd.read_csv('data.csv')
df['feature1'][df['feature2'] == 0].value_counts().plot(kind='bar',label='0')
df['feature1'][df['feature2'] == 1].value_counts().plot(kind='bar',label='1')

but that is not what I actually want because it doesn't stack them.

asked Jan 21 '16 at 07:01 by justanothercoder


3 Answers

Also, I have found another way to do this (with pandas):

df.groupby(['feature1', 'feature2']).size().unstack().plot(kind='bar', stacked=True)

Source: making a stacked barchart in pandas
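
As a minimal, self-contained sketch of the one-liner above, the snippet below builds a small made-up DataFrame and draws the stacked chart. The data values and the names `counts` and `ax` are illustrative assumptions, not the asker's data.csv, and the `Agg` backend is selected only so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import pandas as pd

# Hypothetical data: feature1 has three values, feature2 has two.
df = pd.DataFrame({
    "feature1": ["a", "a", "a", "b", "b", "b", "c", "c", "c", "c"],
    "feature2": [0, 1, 1, 0, 0, 1, 0, 1, 1, 1],
})

# Count rows per (feature1, feature2) pair, then pivot feature2 into columns.
counts = df.groupby(["feature1", "feature2"]).size().unstack()

# One bar per feature1 value, with the feature2 counts stacked on top of each other.
ax = counts.plot(kind="bar", stacked=True)
ax.set_ylabel("row count")
```

To view the result, call `ax.figure.savefig(...)` or switch to an interactive backend. `pd.crosstab(df["feature1"], df["feature2"])` should give the same table of counts if you prefer a single call.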

answered Oct 22 '22 at 20:10 by justanothercoder


I'm not sure how to do this in matplotlib (the default plotting backend for pandas), but if you are willing to try a different plotting library, it is quite easy to do with Bokeh.

Here is an example:

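import pandas as pd
# Note: bokeh.charts was removed from Bokeh in later releases (it was split
# out as the bkcharts package), so this import needs an old Bokeh version.
from bokeh.charts import Bar, output_file, show

x = pd.DataFrame({"gender": ["m", "f", "m", "f", "m", "f"],
                  "enrolments": [500, 20, 100, 342, 54, 47],
                  "class": ["comp-sci", "comp-sci",
                            "psych", "psych",
                            "history", "history"]})

# One bar per class, stacked by gender, with bar heights taken from enrolments.
bar = Bar(x, values='enrolments', label='class', stack='gender',
          title="Number of students enrolled per class",
          legend='top_right', bar_width=1.0)
output_file("myPlot.html")
show(bar)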

[Image: stacked bar plot]
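
For comparison, here is a rough sketch of the same chart drawn with matplotlib directly, stacking each gender's bars on a running `bottom` offset. The names `table`, `bottom`, and `heights` are illustrative, and the `Agg` backend is only there so the snippet runs headless; treat it as one possible approach rather than the canonical one.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

x = pd.DataFrame({"gender": ["m", "f", "m", "f", "m", "f"],
                  "enrolments": [500, 20, 100, 342, 54, 47],
                  "class": ["comp-sci", "comp-sci",
                            "psych", "psych",
                            "history", "history"]})

# Reshape to one row per class and one column per gender.
table = x.pivot(index="class", columns="gender", values="enrolments")

fig, ax = plt.subplots()
bottom = np.zeros(len(table))  # running height of each stack
for gender in table.columns:
    heights = table[gender].to_numpy()
    # Draw this gender's segment on top of whatever is already stacked.
    ax.bar(list(table.index), heights, bottom=bottom, label=gender)
    bottom = bottom + heights

ax.set_title("Number of students enrolled per class")
ax.legend(loc="upper right")
```

After the loop, `bottom` holds the total enrolment per class, which is the full height of each stack.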

answered Oct 22 '22 at 21:10 by ronrest


size() produces a Series holding the row count for each (feature1, feature2) group; these counts become the values on the y axis. unstack() then pivots the inner index level into columns, which gives matplotlib the row and column layout it needs to draw the stacked bar graph.

Essentially it takes

>>> s
one  a    1.0
     b    2.0
two  a    3.0
     b    4.0

and produces:

>>> s.unstack(level=-1)
       a    b
one  1.0  2.0
two  3.0  4.0
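
The Series shown above can be rebuilt and unstacked with the short sketch below. It also illustrates one caveat that matters for counts: if some combination of index values never occurs, `unstack()` leaves a NaN in that cell unless `fill_value` is passed. The names `wide`, `partial`, `with_nan`, and `with_zero` are made up for this example.

```python
import pandas as pd

# Rebuild the two-level Series from the example.
index = pd.MultiIndex.from_tuples(
    [("one", "a"), ("one", "b"), ("two", "a"), ("two", "b")])
s = pd.Series([1.0, 2.0, 3.0, 4.0], index=index)

# The inner level ("a"/"b") becomes the columns; the outer level stays as rows.
wide = s.unstack(level=-1)
print(wide)

# Drop the last entry so that the ("two", "b") combination is missing.
partial = s.iloc[:3]
with_nan = partial.unstack()               # missing cell is NaN
with_zero = partial.unstack(fill_value=0)  # missing cell is filled with 0
print(with_nan)
print(with_zero)
```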
answered Oct 22 '22 at 22:10 by rtphokie