 

Prevent axis & column labels from bleeding off heatmap in matplotlib

I'm trying to plot a confusion matrix using the array below. However, when the heatmap renders, the column and axis labels are bleeding off the plot view and I can't figure out how to control this formatting. Seems like I need a way to set some padding values.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn

array = [
    [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0],
    [0,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
    [0,0,1,0,0,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
    [0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],
    [0,0,0,0,3,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],
    [0,0,0,0,0,2,0,0,0,0,0,0,0,0,0,0,0,1,0,0,2,0,0],
    [0,0,0,0,0,0,3,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],
    [0,1,0,1,0,0,0,6,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,0,1,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,3,1,0,0,0,0,0,0,0,0,0,0,0,0],
    [0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,0,0,5,0,0,0,0,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,1,0,0,0,0,0,7,0,0,0,0,0,1,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,0,1,0,0,0,0,0,1,4,1,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,1,1,0,0,0,0],
    [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,38,0,0,0,0,0],
    [0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,8,0,0,0,1],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,3,0,0,0],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,0,0],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5,0],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,9]]

labels = ['chatbot',
    'business_passwordreset',
    'business_losthomework',
    'oos_generic',
    'frustration',
    'social_generic',
    'business_accesscode_selfstudy',
    'business_assignmentissues',
    'end_chat',
    'bye',
    'thanks',
    'business_accesscode_lost',
    'business_accesscode_redeem',
    'business_accesscode_notreceived',
    'business_accesscode_share',
    'business_accesscode_reuse',
    'business_editorial',
    'Contact_Request',
    'business_accesscode_error',
    'business_accesscode_refund',
    'hello',
    'business_accesscode_purchase',
    'business_accesscode_troubleshoot']


df_cm = pd.DataFrame(array, index=labels, columns=labels)
sn.heatmap(df_cm, annot=True, cmap='Blues')
plt.show()

Rendered plot: confusion matrix

Everything else looks good, but it would be nice to be able to read the labels! Does anyone know what I'm missing?

Asked Dec 01 '25 01:12 by Marty Mitchener


1 Answer

Your labels are really long, so I think your best bet is to create a large figure and then call plt.tight_layout(). As described in the docs:

This module provides routines to adjust subplot params so that subplots are nicely fit in the figure

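import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn

# array and labels are defined as in the question above

# Create a large figure so your labels aren't too crowded
plt.figure(figsize=(13,7))
df_cm = pd.DataFrame(array, index=labels, columns=labels)
sn.heatmap(df_cm, annot=True, cmap='Blues')
# Recompute the subplot margins so the long tick labels fit inside the figure
plt.tight_layout()
plt.show()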
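If you would rather set the padding values yourself, as the question suggests, matplotlib's Figure.subplots_adjust lets you fix the margins explicitly as fractions of the figure size. The sketch below assumes that approach. The helper name plot_confusion_matrix and the default margins (left=0.3, bottom=0.35) are hypothetical choices for labels about as long as 'business_accesscode_troubleshoot' and will likely need tuning for your figure size and font.

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn


def plot_confusion_matrix(cm, labels, left=0.3, bottom=0.35):
    """Hypothetical helper: draw a labelled heatmap with fixed margins.

    left and bottom are fractions (0-1) of the figure width/height that are
    reserved for the y and x tick labels. The defaults are assumptions, not
    values derived from the label text, so adjust them to your labels.
    """
    df_cm = pd.DataFrame(cm, index=labels, columns=labels)
    fig, ax = plt.subplots(figsize=(13, 9))
    sn.heatmap(df_cm, annot=True, cmap='Blues', ax=ax)
    # Explicit padding: shrink the plotting area so the long tick labels
    # on the left and bottom stay inside the figure. right < 1 leaves room
    # for the colorbar's tick labels.
    fig.subplots_adjust(left=left, bottom=bottom, right=0.95, top=0.97)
    return fig, ax
```

With the question's data you would call fig, ax = plot_confusion_matrix(array, labels) followed by plt.show(). Unlike tight_layout(), the margins here do not adapt automatically if the labels or font size change, but they stay exactly where you put them.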


Answered Dec 04 '25 23:12 by sacuL


