 

Fixing Confusion Matrix plot lines

I am trying to plot a confusion matrix as follows:

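import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# argmax(axis=1) turns one-hot labels / per-class scores into class indices
cm = confusion_matrix(testY.argmax(axis=1), predictions.argmax(axis=1))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=lb.classes_)
disp = disp.plot(include_values=True, cmap='viridis', ax=None, xticks_rotation='horizontal')

plt.show()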
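The snippet relies on variables that are not shown. A plausible reading is that `testY` holds one-hot true labels, `predictions` holds per-class model scores, and `lb` is a fitted label binarizer with a `classes_` attribute (scikit-learn's `LabelBinarizer` is assumed here). A self-contained sketch of the same pipeline, with made-up labels and scores, shows what `argmax(axis=1)` hands to `confusion_matrix`:

```python
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelBinarizer

# Hypothetical string labels, one-hot encoded the way `lb` presumably was.
lb = LabelBinarizer()
testY = lb.fit_transform(["cat", "dog", "bird", "dog", "cat", "bird"])

# Hypothetical model scores: one row per sample, one column per class.
predictions = np.array([
    [0.1, 0.8, 0.1],
    [0.2, 0.1, 0.7],
    [0.6, 0.3, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.4, 0.3],
    [0.2, 0.5, 0.3],
])

# argmax(axis=1) picks the column index of each row's largest value,
# i.e. the class index of the one-hot label or of the top-scoring class.
y_true = testY.argmax(axis=1)
y_pred = predictions.argmax(axis=1)

# Rows are true classes, columns are predicted classes, in lb.classes_ order.
cm = confusion_matrix(y_true, y_pred)
print(lb.classes_)
print(cm)
```

`LabelBinarizer` sorts the classes, so the order is bird, cat, dog. The last sample is a bird scored highest as cat, which gives `[[1, 1, 0], [0, 2, 0], [0, 0, 2]]`.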

The result:

[Image: the confusion matrix I get]

As you can see, lines run through the middle of each cell instead of outlining the cells. Because of these lines I can't read the numbers in the cells other than the yellow ones. I am not good with plots, so I can't work out what I need to change.

What I expect:

[Image: the expected matrix]
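Lines through the middle of the cells usually mean the axes grid is switched on, for example by a style sheet or `matplotlibrc` that sets `axes.grid: True`. The question doesn't say which style is active, so that is an assumption. `ConfusionMatrixDisplay` places one tick at the centre of each cell, and grid lines are drawn at the ticks, so they cut through the numbers. A minimal, matplotlib-only sketch that reproduces this with a made-up 3×3 matrix:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical confusion matrix, only for illustration.
cm = np.array([[5, 1, 0], [2, 7, 1], [0, 3, 9]])

# Simulate a style that turns the grid on (assumption about the asker's setup).
with plt.rc_context({"axes.grid": True}):
    fig, ax = plt.subplots()
    ax.imshow(cm, cmap="viridis")
    # One tick per cell centre, the same layout ConfusionMatrixDisplay uses.
    ax.set_xticks(np.arange(cm.shape[1]))
    ax.set_yticks(np.arange(cm.shape[0]))

# Each tick carries a grid line, so with axes.grid on they cross every cell.
grid_on = any(line.get_visible() for line in ax.xaxis.get_gridlines())
print("grid lines visible:", grid_on)
plt.close(fig)
```

If `plt.rcParams["axes.grid"]` is `True` in your session, that is very likely the source of the lines.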

FOUND SOLUTION: hiding the tick marks and turning the grid off fixed it.

plt.tick_params(axis='both', which='both', length=0)  # hide the tick marks
plt.grid(False)  # turn the grid off explicitly; plt.grid(b=None) only toggled it
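The same fix can be sketched end to end. The matrix, the class names, and the `rc_context` that mimics a grid-enabling style are hypothetical stand-ins for the asker's data and settings. The sketch calls `plt.grid(False)`, which hides the grid unconditionally. `plt.grid(b=None)` with no other arguments toggles it instead, and Matplotlib 3.5 renamed the `b` parameter to `visible`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

# Hypothetical matrix and class names standing in for cm and lb.classes_.
cm = np.array([[5, 1, 0], [2, 7, 1], [0, 3, 9]])
labels = ["bird", "cat", "dog"]

# Assumption: the asker's style turns the grid on; simulate that here.
with plt.rc_context({"axes.grid": True}):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp = disp.plot(include_values=True, cmap="viridis", ax=None,
                     xticks_rotation="horizontal")

    # The fix: hide the tick marks and the grid on the current (matrix) axes.
    plt.tick_params(axis="both", which="both", length=0)
    plt.grid(False)

ax = disp.ax_
gridlines = list(ax.xaxis.get_gridlines()) + list(ax.yaxis.get_gridlines())
grid_visible = any(line.get_visible() for line in gridlines)
print("grid lines visible:", grid_visible)
plt.close(disp.figure_)
```

The pyplot calls work here because `plot(ax=None)` creates a new figure and axes, and adding the colorbar leaves the matrix axes as the current one.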
tasin95 asked May 22 '26 at 03:05



2 Answers

Turn the grid off, e.g.:

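import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

fig, ax = plt.subplots(figsize=(10, 10))
ax.grid(False)  # switch off the grid lines that would cross the cells

...

disp = ConfusionMatrixDisplay(...)
_ = disp.plot(..., ax=ax, ...)  # draw onto the axes whose grid is off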
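Filled in as a runnable sketch, this approach looks roughly like the following. The matrix and class names are hypothetical placeholders for the elided arguments, and the `rc_context` simulating a grid-enabled style is an assumption about the asker's setup. Turning the grid off before plotting is enough because the ticks that `ConfusionMatrixDisplay` creates afterwards inherit the axes' grid setting.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

# Hypothetical data standing in for the elided arguments.
cm = np.array([[5, 1, 0], [2, 7, 1], [0, 3, 9]])
labels = ["bird", "cat", "dog"]

# Assumption: a style with the grid switched on, as in the question.
with plt.rc_context({"axes.grid": True}):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.grid(False)  # switch the grid off before plotting onto this axes

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(include_values=True, cmap="viridis", ax=ax,
              xticks_rotation="horizontal")

gridlines = list(ax.xaxis.get_gridlines()) + list(ax.yaxis.get_gridlines())
print("grid lines visible:", any(line.get_visible() for line in gridlines))
plt.close(fig)
```

Creating the axes yourself also lets you control the figure size, which `plot(ax=None)` does not expose.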
Asaf R answered May 23 '26 at 15:05



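Call plt.grid(False) after plotting and before plt.show():

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

cm = confusion_matrix(testY.argmax(axis=1), predictions.argmax(axis=1))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=lb.classes_)
disp = disp.plot(include_values=True, cmap='viridis', ax=None, xticks_rotation='horizontal')
plt.grid(False)  # hide the grid lines drawn through the cells of the current axes
plt.show()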
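An alternative not mentioned in either answer: if a global style should keep its grid for other plots, the setting can be overridden only while the matrix is drawn, using Matplotlib's `rc_context`. A sketch under the same assumptions as above (hypothetical matrix and labels; the global `axes.grid = True` line simulates the asker's style):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

cm = np.array([[5, 1, 0], [2, 7, 1], [0, 3, 9]])  # hypothetical matrix
labels = ["bird", "cat", "dog"]                   # hypothetical class names

plt.rcParams["axes.grid"] = True  # assumption: the global style enables the grid

# Axes created inside this block start with the grid off; the global
# setting is restored when the block exits.
with plt.rc_context({"axes.grid": False}):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(include_values=True, cmap="viridis", xticks_rotation="horizontal")

ax = disp.ax_
gridlines = list(ax.xaxis.get_gridlines()) + list(ax.yaxis.get_gridlines())
print("grid lines visible:", any(line.get_visible() for line in gridlines))
print("global axes.grid still:", plt.rcParams["axes.grid"])
plt.close(disp.figure_)
```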
Jayakrishnan answered May 23 '26 at 15:05
