
Group labels in matplotlib barchart using Pandas MultiIndex

I have a pandas DataFrame with a MultiIndex:

group   subgroup    obs_1    obs_2
GroupA  Elem1       4        0
        Elem2       34       2
        Elem3       0        10
GroupB  Elem4       5        21

and so on. I want to plot this as a bar chart with the x-axis labels grouped by the first index level. As noted in this SO question, this is doable in matplotlib, but I'd rather (if possible) use the fact that I already know the hierarchy, thanks to the MultiIndex. Currently, each tick label shows the index entry as a tuple.

Is such a thing possible?

Einar asked Apr 01 '14 08:04


2 Answers

If you have just two levels in the MultiIndex, I believe the following will be easier:

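import matplotlib.pyplot as plt

# DF is the MultiIndexed DataFrame from the question
plt.figure()
ax = plt.gca()
DF.plot(kind='bar', ax=ax)
plt.grid(True, 'both')
# One tick per bar; these become the minor (subgroup) ticks
minor_XT = ax.get_xaxis().get_majorticklocs()
# Temporarily store the tick positions to find the first bar of each group
DF['XT_V'] = minor_XT
# sort=False keeps the groups in index order, matching unique() below
major_XT = DF.groupby(by=DF.index.get_level_values(0), sort=False).first()['XT_V'].tolist()
del DF['XT_V']
ax.set_xticks(minor_XT, minor=True)
ax.set_xticklabels(DF.index.get_level_values(1), minor=True)
# Push the major (group) labels below the subgroup labels
ax.tick_params(which='major', pad=15)
_ = plt.xticks(major_XT, DF.index.get_level_values(0).unique(), rotation=0)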

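The key step above is finding the position of the first bar in each top-level group, which is where the major tick goes. A minimal sketch of just that step in plain Python, assuming the index is available as a list of tuples (for example from `DF.index.tolist()`); the helper name `first_positions` is made up for illustration and is not part of pandas or matplotlib:

```python
def first_positions(index_tuples):
    """Map each top-level label to the position of its first row.

    Hypothetical helper: positions are 0, 1, 2, ... which matches the
    default bar locations of a pandas bar plot.
    """
    firsts = {}
    for position, entry in enumerate(index_tuples):
        # setdefault only stores the first position seen for each group
        firsts.setdefault(entry[0], position)
    return firsts


index_tuples = [
    ("GroupA", "Elem1"),
    ("GroupA", "Elem2"),
    ("GroupA", "Elem3"),
    ("GroupB", "Elem4"),
]
major = first_positions(index_tuples)
print(major)  # {'GroupA': 0, 'GroupB': 3}
```

Because a dict keeps insertion order, the positions and their labels stay paired even when the index is not sorted alphabetically, so something like `plt.xticks(list(major.values()), list(major.keys()), rotation=0)` should label the groups consistently.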

And a slightly more involved, but more general, solution (it works for any number of levels):

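import numpy as np
import matplotlib.pyplot as plt

def cvt_MIdx_tcklab(df):
    # One row per index entry, one column per level
    Midx_ar = np.array(df.index.tolist())
    Blank_ar = Midx_ar.copy()
    col_idx = np.arange(Midx_ar.shape[0])
    for i in range(Midx_ar.shape[1]):
        # Keep each label only at its first occurrence in this level
        val, idx = np.unique(Midx_ar[:, i], return_index=True)
        Blank_ar[idx, i] = val
        idx = ~np.isin(col_idx, idx)
        Blank_ar[idx, i] = ''
    # Innermost level goes on the first line of each label;
    # return a list because map() is lazy in Python 3
    return list(map('\n'.join, np.fliplr(Blank_ar)))

plt.figure()
ax = plt.gca()
DF.plot(kind='bar', ax=ax)
ax.set_xticklabels(cvt_MIdx_tcklab(DF), rotation=0)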
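
One thing to watch with the function above: `np.unique(..., return_index=True)` keeps only the first occurrence of a label in the whole column, so a subgroup name that repeats under a different group would be blanked. Below is a sketch of a variant that blanks a label only when it and all outer levels match the previous row. It is plain Python, the name `blank_repeated_labels` is made up for illustration, and it assumes the index is passed as a list of tuples (for example `DF.index.tolist()`):

```python
def blank_repeated_labels(index_tuples):
    """Build one multi-line tick label per row, innermost level first.

    Hypothetical helper: a level is shown only when it, or any level
    outside it, differs from the previous row.
    """
    labels = []
    previous = None
    for current in index_tuples:
        parts = []
        for level, value in enumerate(current):
            changed = (
                previous is None
                or current[: level + 1] != previous[: level + 1]
            )
            parts.append(str(value) if changed else "")
        # Reverse so the innermost level sits closest to the axis
        labels.append("\n".join(reversed(parts)))
        previous = current
    return labels


index_tuples = [
    ("GroupA", "Elem1"),
    ("GroupA", "Elem2"),
    ("GroupB", "Elem1"),  # same subgroup name under another group
]
print(blank_repeated_labels(index_tuples))
# ['Elem1\nGroupA', 'Elem2\n', 'Elem1\nGroupB']
```

The result can be passed to `ax.set_xticklabels(..., rotation=0)` in the same way as the output of `cvt_MIdx_tcklab`.
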
CT Zhu answered Oct 07 '22 11:10


I don't think there is a nice, standard way of plotting MultiIndex DataFrames. I found the following solution by @Stein aesthetically pleasing, and I've adapted his example to your data:

import pandas as pd
import matplotlib.pyplot as plt
from itertools import groupby
import numpy as np
# %matplotlib inline  # IPython/Jupyter magic; uncomment only in a notebook

group = ('Group_A', 'Group_B')
subgroup = ('elem1', 'elem2', 'elem3', 'elem4')
obs = ('obs_1', 'obs_2')
index = pd.MultiIndex.from_tuples([('Group_A','elem1'),('Group_A','elem2'),('Group_A','elem3'),('Group_B','elem4')],
   names=['group', 'subgroup'])
values = np.array([[4,0],[34,2],[0,10],[5,21]])
df = pd.DataFrame(index=index)
df['obs_1'] = values[:,0]
df['obs_2'] = values[:,1]

def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='gray')
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index,level):
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for i in g)) for k,g in groupby(labels)]

def label_group_bar_table(ax, df):
    ypos = -.1
    scale = 1./df.index.size
    for level in range(df.index.nlevels)[::-1]:
        pos = 0
        for label, rpos in label_len(df.index,level):
            lxpos = (pos + .5 * rpos)*scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos*scale, ypos)
            pos += rpos
        add_line(ax, pos*scale , ypos)
        ypos -= .1

ax = df.plot(kind='bar',stacked=False)
#Below 2 lines remove default labels
ax.set_xticklabels('')
ax.set_xlabel('')
label_group_bar_table(ax, df)

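This draws the subgroup labels in a row directly under the bars and the group labels in a second row beneath them, with short gray vertical lines marking the boundaries of each label.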
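
The arithmetic inside `label_group_bar_table` can be checked on its own. Here is a minimal sketch, in plain Python, of how the run lengths from `groupby` turn into axes-fraction positions for one index level. The helper name `group_spans` is made up for illustration, and it assumes (as the code above does) that the bars are evenly spaced across the full width of the axes:

```python
from itertools import groupby


def group_spans(labels):
    """Return (label, left, center, right) for each run of equal labels.

    Hypothetical helper: positions are fractions of the axes width,
    assuming len(labels) equally wide slots, one per bar.
    """
    n = len(labels)
    spans = []
    pos = 0
    for label, run in groupby(labels):
        length = sum(1 for _ in run)
        left = pos / n
        center = (pos + length / 2) / n  # where the text is centered
        right = (pos + length) / n
        spans.append((label, left, center, right))
        pos += length
    return spans


outer = ["Group_A", "Group_A", "Group_A", "Group_B"]
for span in group_spans(outer):
    print(span)
# ('Group_A', 0.0, 0.375, 0.75)
# ('Group_B', 0.75, 0.875, 1.0)
```

The `center` value corresponds to `lxpos` in the function above, and `left` and `right` are where the separator lines are drawn. Note that `groupby` only merges consecutive equal labels, so a group that is split across non-adjacent rows would get one label per run.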

Ramon Crehuet answered Oct 07 '22 11:10