 

matplotlib: How to combine multiple bars with lines

I have a very simple grouped-bar chart with 5 groups of 4 bars each. I used the example from the matplotlib documentation, which goes like this:

import numpy as np
import matplotlib.pyplot as plt

ind = np.arange(5)
avg_bar1 = (81191,79318,57965,60557,14793)
avg_bar2 = (26826,26615,31364,31088,55472)
avg_bar3 = (36232,38038,38615,39014,40812)
avg_bar4 = (26115,25879,25887,28326,27988)

fig, ax = plt.subplots()

rects1 = ax.bar(ind, avg_bar1, 0.15, label='bar1')
rects2 = ax.bar(ind + 0.15, avg_bar2, 0.15, label='bar2')
rects3 = ax.bar(ind + 0.30, avg_bar3, 0.15, label='bar3')
rects4 = ax.bar(ind + 0.45, avg_bar4, 0.15, label='bar4')

plt.xlabel('Distributions')
plt.ylabel('ms')
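# Bars are centred on their x positions (the default since matplotlib 2.0),
# so the middle of each group of four 0.15-wide bars is at ind + 0.225
plt.xticks(ind + 0.225, ('50/50', '60/40', '70/30', '80/20', '90/10'))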
plt.legend()

plt.tight_layout()
plt.show()

Problem

However, the values of corresponding bars across groups (e.g. bar1 in group 1 and bar1 in group 2) often do not differ much from each other, so the trend is hard to see from the bars alone.

What I want

So I want to add lines that make the trend of each bar series clearer. Each line should connect the top of a bar in one group to the top of the corresponding bar in the other groups (bar1 to bar1, bar2 to bar2, and so on).

I couldn't find anything similar on the web.

Is that possible?

asked Mar 12 '23 19:03 by matthaeus


1 Answer

I recently started learning Matplotlib and NumPy and found your question quite interesting.

So I spent some time creating a combined chart: your grouped bar chart with a line chart drawn on top of it.

Result: [image of the resulting chart]

Code:

import numpy as np
import matplotlib.pyplot as plt

ind = np.arange(5)
avg_bar1 = (81191,79318,57965,60557,14793)
avg_bar2 = (26826,26615,31364,31088,55472)
avg_bar3 = (36232,38038,38615,39014,40812)
avg_bar4 = (26115,25879,25887,28326,27988)

rects1 = plt.bar(ind, avg_bar1, 0.15, color='#ff0000',label='bar1')
rects2 = plt.bar(ind + 0.15, avg_bar2, 0.15, color='#00ff00', label='bar2')
rects3 = plt.bar(ind + 0.30, avg_bar3, 0.15, color='#0000ff', label='bar3')
rects4 = plt.bar(ind + 0.45, avg_bar4, 0.15, color='#000000', label='bar4')

# For each group, pick the tallest of its four bars and record the
# centre x position and height of that bar as one point of the trend line
high_point_x = []
high_point_y = []
for i in range(len(ind)):
    group = (rects1[i], rects2[i], rects3[i], rects4[i])
    tallest = max(group, key=lambda rect: rect.get_height())
    high_point_x.append(tallest.get_x() + tallest.get_width() / 2.0)
    high_point_y.append(tallest.get_height())

trend_line = plt.plot(high_point_x,high_point_y,marker='o', color='#5b74a8', label='Trend Line')

plt.xlabel('Distributions')
plt.ylabel('ms')
plt.xticks(ind + 0.225, ('50/50', '60/40', '70/30', '80/20', '90/10'))  # centre of each group of four bars
plt.legend()
plt.show()
answered Mar 15 '23 05:03 by arsho
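The answer above draws a single line through the tallest bar of each group. If you instead want one line per bar series, as the question describes (bar1 across all five groups, bar2 across all five groups, and so on), a minimal sketch with the same data might look like this. It assumes matplotlib 2.0 or later, where `ax.bar` centres bars on their x positions by default. The helper `bar_tops`, the `series` dict and the output filename are hypothetical names chosen for illustration, not part of either post.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import matplotlib.pyplot as plt

ind = np.arange(5)
width = 0.15
# Same data as the question, keyed by series label (hypothetical structure)
series = {
    "bar1": (81191, 79318, 57965, 60557, 14793),
    "bar2": (26826, 26615, 31364, 31088, 55472),
    "bar3": (36232, 38038, 38615, 39014, 40812),
    "bar4": (26115, 25879, 25887, 28326, 27988),
}


def bar_tops(container):
    """Hypothetical helper: x centres and heights of the bars in one BarContainer."""
    xs = [rect.get_x() + rect.get_width() / 2.0 for rect in container]
    ys = [rect.get_height() for rect in container]
    return xs, ys


fig, ax = plt.subplots()
lines = []
for k, (label, values) in enumerate(series.items()):
    # Offset each series by k bar widths inside every group
    bars = ax.bar(ind + k * width, values, width, label=label)
    xs, ys = bar_tops(bars)
    # Reuse the bars' colour; the line has no label, so the legend lists each series once
    colour = bars.patches[0].get_facecolor()
    (line,) = ax.plot(xs, ys, marker="o", color=colour)
    lines.append(line)

ax.set_xlabel("Distributions")
ax.set_ylabel("ms")
ax.set_xticks(ind + 1.5 * width)  # middle of each group of four bars
ax.set_xticklabels(["50/50", "60/40", "70/30", "80/20", "90/10"])
ax.legend()
fig.tight_layout()
fig.savefig("grouped_bars_with_lines.png")  # hypothetical output file
```

Each line starts from the centre of a bar's top edge, so it connects bar1 to bar1, bar2 to bar2, and so on across the groups. To display the chart in a window instead, drop the `matplotlib.use("Agg")` line and replace `fig.savefig(...)` with `plt.show()`.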