 

Confusion Matrix with number of classified/misclassified instances on it (Python/Matplotlib)

I am plotting a confusion matrix with matplotlib with the following code:

from numpy import array
import matplotlib.pyplot as plt
from matplotlib import cm

# Raw counts, one row per class
conf_arr = [[33, 2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
            [3, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 4, 41, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 1, 0, 30, 0, 6, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 38, 10, 0, 0, 0, 0, 0],
            [0, 0, 0, 3, 1, 39, 0, 0, 0, 0, 4],
            [0, 2, 2, 0, 4, 1, 31, 0, 0, 0, 2],
            [0, 1, 0, 0, 0, 0, 0, 36, 0, 2, 0],
            [0, 0, 0, 0, 0, 0, 1, 5, 37, 5, 1],
            [3, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38]]

# Normalize each row so that its entries sum to 1
norm_conf = []
for row in conf_arr:
    total = float(sum(row))
    norm_conf.append([j / total for j in row])

fig = plt.figure()
ax = fig.add_subplot(111)
res = ax.imshow(array(norm_conf), cmap=cm.jet, interpolation='nearest')
cb = fig.colorbar(res)
fig.savefig("confmat.png", format="png")
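The row normalization loop can also be done in one step with NumPy broadcasting. This is a minimal sketch, assuming Python 3 (so `/` is true division) and a NumPy version that supports the `keepdims` argument of `sum`; it reuses the same counts as conf_arr.

```python
import numpy as np

# Same counts as conf_arr, one row per class
conf_arr = np.array([[33, 2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
                     [3, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 4, 41, 0, 0, 0, 0, 0, 0, 0, 1],
                     [0, 1, 0, 30, 0, 6, 0, 0, 0, 0, 1],
                     [0, 0, 0, 0, 38, 10, 0, 0, 0, 0, 0],
                     [0, 0, 0, 3, 1, 39, 0, 0, 0, 0, 4],
                     [0, 2, 2, 0, 4, 1, 31, 0, 0, 0, 2],
                     [0, 1, 0, 0, 0, 0, 0, 36, 0, 2, 0],
                     [0, 0, 0, 0, 0, 0, 1, 5, 37, 5, 1],
                     [3, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38]])

# Divide each row by its own total; keepdims=True keeps the totals as an
# (n, 1) column so the division broadcasts across every column of that row
norm_conf = conf_arr / conf_arr.sum(axis=1, keepdims=True)

print(norm_conf.sum(axis=1))  # each row now sums to 1 (up to floating-point rounding)
```

The result is already a NumPy array, so it can be passed to `ax.imshow` without wrapping it in `array(...)`. Note that a row whose total is 0 would produce NaNs here, just as it would raise a ZeroDivisionError in the loop version.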

But I want the confusion matrix to show the numbers on it, like this graphic (the right one). How can I plot the values of conf_arr on the graphic?

[Image: confusion matrix]

asked May 24 '10 at 14:05 by Pinkie


1 Answer

You can use text (plt.text, or the Axes method ax.text) to put arbitrary text on your plot. For example, inserting the following lines into your code will write the counts on the cells (note that the first and last lines are from your code, to show you where to insert my lines):

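res = ax.imshow(array(norm_conf), cmap=cm.jet, interpolation='nearest')
# Write each nonzero count in its cell: column j is the x position, row i the y position
for i, cas in enumerate(conf_arr):
    for j, c in enumerate(cas):
        if c > 0:
            # Center the label on the cell instead of nudging it by hand
            ax.text(j, i, str(c), ha='center', va='center', fontsize=14)
cb = fig.colorbar(res)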
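For reference, here is a self-contained sketch that combines the question's plot with the text labels. The helper name plot_confusion_matrix is hypothetical (it is not a Matplotlib function), and the rule for switching between white and black text is an arbitrary heuristic tuned for the jet colormap, whose ends are dark blue and dark red while the middle is light.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the script also runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(counts, filename="confmat.png"):
    """Color cells by row-normalized rate and write the raw count in each cell.

    plot_confusion_matrix is a hypothetical helper, not part of Matplotlib.
    """
    counts = np.asarray(counts)
    # Row-normalize so the colors show per-row proportions
    norm = counts / counts.sum(axis=1, keepdims=True)

    fig, ax = plt.subplots()
    res = ax.imshow(norm, cmap="jet", interpolation="nearest", vmin=0.0, vmax=1.0)
    fig.colorbar(res)

    for i in range(counts.shape[0]):      # row index -> y coordinate
        for j in range(counts.shape[1]):  # column index -> x coordinate
            if counts[i, j] > 0:
                # Heuristic (assumption): white text on the dark ends of jet,
                # black text on the lighter middle of the colormap
                dark = norm[i, j] < 0.2 or norm[i, j] > 0.8
                ax.text(j, i, str(counts[i, j]), ha="center", va="center",
                        color="white" if dark else "black", fontsize=10)

    fig.savefig(filename, format="png")
    return fig, ax
```

With the question's data, calling plot_confusion_matrix(conf_arr) should write confmat.png with every nonzero count centered in its cell. Fixing vmin and vmax to 0 and 1 keeps the colors comparable across different matrices, since each row always sums to 1.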

[Image: matrix with numbers]

answered Sep 20 '22 at 15:09 by tom10