Can't label multiple rows of sns.catplot()

Here's my source code:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

plot = sns.catplot(x='Year',
                   y='Graduation Rate',
                   col='Group',
                   hue='Campus',
                   kind='bar',
                   col_wrap=4,
                   data=mbk_grad.sort_values(['Group', 'Campus']))

for i in np.arange(2):
    for j in np.arange(4):
        ax = plot.facet_axis(i,j) 
        for p in ax.patches:
            if str(p.get_height()) != 'nan':
                ax.text(p.get_x() + 0.06, p.get_height() * .8, '{0:.2f}%'.format(p.get_height()), color='white', rotation='vertical', size='large')

plt.show()

In the output (screenshot not reproduced here), only the first row of facets gets the percentage labels.

How do I get the rows after the first one labeled the same way the first row is? Why isn't my nested for-loop working?

asked Mar 15 '26 at 23:03 by ekazubuike



1 Answer

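If you look at plot.axes.shape, you will see that the array of axes is not (2, 4) as you expected but (8,), a 1D array. This is because you are using col_wrap rather than defining a row/column layout grid. When col_wrap is set, FacetGrid.facet_axis(row_i, col_j) ignores row_i and returns plot.axes.flat[col_j], so your nested loop labels the first four axes twice and never reaches the second row.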
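To make the indexing concrete, here is a pure-Python sketch of which facet each loop lands on. It assumes the lookup rule seaborn's FacetGrid.facet_axis uses (with col_wrap, it returns axes.flat[col_j]; without it, axes[row_i, col_j] on a 2D grid). The function name facet_index is hypothetical and only mimics that rule; no seaborn or data is needed to run it.

```python
def facet_index(row_i, col_j, col_wrap=None, ncol=4):
    """Hypothetical stand-in for the flat index FacetGrid.facet_axis picks.

    With col_wrap set, the axes are stored as a 1D array and row_i is
    ignored. Without col_wrap, the axes form a (nrow, ncol) grid, whose
    flat (row-major) index is row_i * ncol + col_j.
    """
    if col_wrap is not None:
        return col_j
    return row_i * ncol + col_j


# The question's nested loop: i in 0..1, j in 0..3, with col_wrap=4
question_hits = [facet_index(i, j, col_wrap=4) for i in range(2) for j in range(4)]

# The single loop: facet_axis(0, i) for i in 0..7, with col_wrap=4
answer_hits = [facet_index(0, i, col_wrap=4) for i in range(8)]

print(question_hits)  # [0, 1, 2, 3, 0, 1, 2, 3]: first row only, labeled twice
print(answer_hits)    # [0, 1, 2, 3, 4, 5, 6, 7]: all eight facets
```

The duplicate hits explain why the first row still looks right: each of its labels is simply drawn twice in the same spot.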

plot = sns.catplot(x='Year',
                   y='Graduation Rate',
                   col='Group',
                   hue='Campus',
                   kind='bar',
                   col_wrap=4,
                   data=mbk_grad.sort_values(['Group', 'Campus']))

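# With col_wrap=4 the axes form a 1D array of 8, so use a single loop;
# facet_axis(0, i) then returns plot.axes.flat[i] (the row index is ignored)
for i in np.arange(8):
    ax1 = plot.facet_axis(0, i)
    for p in ax1.patches:
        # Skip bars with no data (NaN height)
        if not np.isnan(p.get_height()):
            ax1.text(p.get_x() + 0.06, p.get_height() * .8, '{0:.2f}%'.format(p.get_height()), color='white', rotation='vertical', size='large')

plt.show()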

Output (screenshot not reproduced here): the facets in both rows are now labeled.
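A slightly more general variant iterates over plot.axes.flat directly, which visits every facet whether or not col_wrap is set, and checks for NaN with math.isnan instead of comparing strings. This is a sketch, not the answer's code: label_bars is a hypothetical helper name, and the demo builds a plain matplotlib 2 x 4 grid of bar charts as a stand-in for the catplot so it runs without seaborn or the original data.

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def label_bars(ax, fmt="{0:.2f}%"):
    """Write each bar's height inside the bar, as in the loop above.

    Bars with a NaN height (e.g. a hue level missing in a facet) are skipped.
    Returns the number of labels added.
    """
    added = 0
    for p in ax.patches:
        h = p.get_height()
        if math.isnan(h):
            continue
        ax.text(p.get_x() + 0.06, h * 0.8, fmt.format(h),
                color="white", rotation="vertical", size="large")
        added += 1
    return added


# Stand-in for the catplot grid: 2 rows x 4 columns of bar charts
fig, axes = plt.subplots(2, 4)
for k, ax in enumerate(axes.flat):
    ax.bar([0, 1], [50.0 + k, 60.0 + k])

# Iterating axes.flat visits every subplot, whatever the grid's shape
counts = [label_bars(ax) for ax in axes.flat]
print(counts)  # [2, 2, 2, 2, 2, 2, 2, 2]
```

With the grid from the question, the equivalent call would be `for ax in plot.axes.flat: label_bars(ax)`, which works for both the 1D array produced by col_wrap and a 2D row/col grid.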

PS. I attended Booker T. Washington (HSEP). Where is my HISD school here?

answered Mar 18 '26 at 12:03 by Scott Boston



