
Plot two levels of x_ticklabels on a pandas multi-index dataframe [duplicate]

I have a multi-index dataframe whose index has been derived from dates. It comprises year and quarter values.
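The question does not show how the index was built, so the following is only a sketch of one plausible way to get a (year, quarter) index from dated records. The `raw` frame, its `date` and `serotype` columns, and the values in it are hypothetical and are not part of the original dataset.

```python
import pandas as pd

# Hypothetical raw data: one row per dated record (not the asker's real data).
raw = pd.DataFrame({
    "date": pd.to_datetime(["2002-01-15", "2002-02-03", "2002-05-20", "2003-11-30"]),
    "serotype": ["7v", "7v", "13v", "7v"],
})

# Group by calendar year and quarter, count records per serotype,
# then pivot the serotypes into columns.
year = raw["date"].dt.year.rename("year")
quarter = raw["date"].dt.quarter.rename("quarter")
counts = (
    raw.groupby([year, quarter, "serotype"])
       .size()
       .unstack("serotype", fill_value=0)
)
print(counts)
```

The result has the same shape as the sample below: a two-level index of (year, quarter) and one column per serotype.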

What I want to achieve is a plot with two sets of tick labels on the x axis. The minor tick labels should represent the quarter values (1 to 4) and the major tick labels the year values. However, I do not want a year label on every bar, only one year label for each group of four quarters.

This is straightforward to do in an Excel chart. Here is an example of what I am trying to reproduce:

[Image: example Excel chart]

Here is a sample from my dataset.

serotype_df = pd.DataFrame({'13v': {(2002, 1): 5,
  (2002, 2): 9,
  (2002, 3): 23,
  (2002, 4): 11,
  (2003, 1): 1,
  (2003, 2): 12,
  (2003, 3): 22,
  (2003, 4): 15,
  (2004, 1): 10,
  (2004, 2): 11,
  (2004, 3): 30,
  (2004, 4): 11,
  (2005, 1): 9,
  (2005, 2): 20,
  (2005, 3): 20,
  (2005, 4): 7},
 '23v': {(2002, 1): 1,
  (2002, 2): 8,
  (2002, 3): 18,
  (2002, 4): 5,
  (2003, 1): 5,
  (2003, 2): 16,
  (2003, 3): 13,
  (2003, 4): 7,
  (2004, 1): 4,
  (2004, 2): 4,
  (2004, 3): 20,
  (2004, 4): 5,
  (2005, 1): 4,
  (2005, 2): 5,
  (2005, 3): 10,
  (2005, 4): 5},
 '7v': {(2002, 1): 30,
  (2002, 2): 75,
  (2002, 3): 148,
  (2002, 4): 68,
  (2003, 1): 26,
  (2003, 2): 75,
  (2003, 3): 147,
  (2003, 4): 67,
  (2004, 1): 32,
  (2004, 2): 84,
  (2004, 3): 151,
  (2004, 4): 62,
  (2005, 1): 21,
  (2005, 2): 49,
  (2005, 3): 81,
  (2005, 4): 26},
 'Non-typed': {(2002, 1): 1,
  (2002, 2): 2,
  (2002, 3): 4,
  (2002, 4): 4,
  (2003, 1): 3,
  (2003, 2): 5,
  (2003, 3): 9,
  (2003, 4): 8,
  (2004, 1): 1,
  (2004, 2): 4,
  (2004, 3): 6,
  (2004, 4): 4,
  (2005, 1): 4,
  (2005, 2): 10,
  (2005, 3): 7,
  (2005, 4): 11},
 'Non-vaccine': {(2002, 1): 2,
  (2002, 2): 7,
  (2002, 3): 10,
  (2002, 4): 6,
  (2003, 1): 4,
  (2003, 2): 5,
  (2003, 3): 13,
  (2003, 4): 8,
  (2004, 1): 2,
  (2004, 2): 4,
  (2004, 3): 19,
  (2004, 4): 8,
  (2005, 1): 4,
  (2005, 2): 3,
  (2005, 3): 15,
  (2005, 4): 5}})

I have tried to use some code from a different SO example. Here is the code I tried.

import pandas as pd
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(14, 8), dpi=200)
ax = fig.add_subplot(111)
ax1 = ax.twiny()  # second x axis, drawn along the top

# Bottom axis: keep only the quarter from each "(year, quarter)" label.
serotype_df.plot(kind='bar', ax=ax, stacked=True)
trunc = lambda x: x.strip("()").split(" ")[1]
tl = [trunc(t.get_text()) for t in ax.get_xticklabels()]
ax.set_xticklabels(tl, rotation=0)

# Top axis: keep only the year, which repeats for every quarter.
serotype_df.plot(kind='bar', ax=ax1, stacked=True)
trunc0 = lambda x: x.strip("()").split(", ")[0]
tl = [trunc0(t.get_text()) for t in ax1.get_xticklabels()]
ax1.set_xticklabels(tl)

I have the quarter labels exactly where I want them. I just can't get the year labels to appear only once per year.

Any help is greatly appreciated.

asked Aug 13 '18 09:08 by John


1 Answer

Try the following code. It creates one subplot for each value of the first index level (the year, in your case), plots that year's quarters in it, and uses the year as the subplot's x-axis label.

import matplotlib.pyplot as plt

# Assumes serotype_df is defined as in the question.
def plot_function(x, ax):
    # Plot the quarters of year x on its own subplot and label it with the year.
    serotype_df.xs(x).plot(kind='bar', stacked=True, ax=ax, legend=False)
    ax.set_xlabel(x, weight='bold')
    ax.tick_params(axis='both', which='both', length=0)  # hide the tick marks
    return ax

n_subplots = len(serotype_df.index.levels[0])
fig, axes = plt.subplots(nrows=1, ncols=n_subplots, sharey=True, figsize=(14, 8))  # width, height

# Map each year to its subplot.
graph = dict(zip(serotype_df.index.levels[0], axes))
plots = [plot_function(x, graph[x]) for x in graph]
fig.subplots_adjust(wspace=0)  # no gap between subplots, so they read as one chart

plt.legend()
plt.show()
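The code above relies on two MultiIndex operations: `index.levels[0]` to list the years and `xs` to pull out one year's rows. The following sketch shows what each returns on a small stand-in frame (`demo` is a made-up name; its values are the first eight '7v' counts from the question).

```python
import pandas as pd

# Small stand-in for serotype_df: a (year, quarter) MultiIndex with one column.
idx = pd.MultiIndex.from_product([[2002, 2003], [1, 2, 3, 4]])
demo = pd.DataFrame({"7v": [30, 75, 148, 68, 26, 75, 147, 67]}, index=idx)

years = list(demo.index.levels[0])  # unique values of the outer level
one_year = demo.xs(2002)            # rows for 2002, now indexed by quarter only

print(years)
print(one_year)
```

One thing to watch: `levels` can still list values that were filtered out of the frame, so after slicing it may be safer to use `index.get_level_values(0).unique()` or to call `index.remove_unused_levels()` first.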

If you're not making many changes to each subplot, you can do it all in one line:

plots = [serotype_df.xs(x).plot(kind='bar', stacked=True, ax=graph[x], legend=False).set_xlabel(x, weight='bold') for x in graph]

That way you don't need plot_function at all.

[Image: resulting plot]
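As an alternative to one subplot per year, the two label levels can also be drawn on a single axes. This is a different technique from the one in this answer and is only a sketch: major ticks carry the quarter for each bar, and minor ticks placed at the centre of each year's group carry the year. The function name `plot_two_level_ticks` and the `demo` frame are made up for illustration, and the `pad=20` offset is a guess that may need tuning.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_two_level_ticks(df):
    """Stacked bars with a quarter label per bar and one year label per group."""
    fig, ax = plt.subplots(figsize=(14, 8))
    df.plot(kind="bar", stacked=True, ax=ax)

    years = df.index.get_level_values(0)
    quarters = df.index.get_level_values(1)
    positions = np.arange(len(df))  # pandas draws bar i at x = i

    # Major ticks: one per bar, labelled with the quarter.
    ax.set_xticks(positions)
    ax.set_xticklabels([str(q) for q in quarters], rotation=0)

    # Minor ticks: one per year, centred under that year's bars.
    unique_years = years.unique()
    centers = [positions[years == y].mean() for y in unique_years]
    ax.set_xticks(centers, minor=True)
    ax.set_xticklabels([str(y) for y in unique_years], minor=True, weight="bold")
    ax.tick_params(axis="x", which="minor", length=0, pad=20)  # push years below quarters
    ax.set_xlabel("")
    return fig, ax


# Stand-in data taken from the first two years of the question's sample.
idx = pd.MultiIndex.from_product([[2002, 2003], [1, 2, 3, 4]])
demo = pd.DataFrame({"7v": [30, 75, 148, 68, 26, 75, 147, 67],
                     "13v": [5, 9, 23, 11, 1, 12, 22, 15]}, index=idx)
fig, ax = plot_two_level_ticks(demo)
```

Call `plt.show()` or `fig.savefig(...)` to see the result. By default matplotlib drops minor ticks that coincide with major ticks, so with an odd number of bars per year the centred year label can disappear. With four quarters per year the centres fall between bars and this does not arise.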

answered Nov 15 '22 18:11 by gyx-hh