
Matplotlib matshow with many string labels

Today I tried to plot a confusion matrix from my classification model.

After searching a few pages, I found that matshow from pyplot can help me.

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(cm)
    plt.title(title)
    fig.colorbar(cax)
    if labels:
        ax.set_xticklabels([''] + labels)
        ax.set_yticklabels([''] + labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

It works well if I have only a few labels:

y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a']
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a']
labels = sorted(set(y_true))  # sorted, to match the class order confusion_matrix uses
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)


But if I have many labels, some of them are not shown:

y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
labels = sorted(set(y_true))  # sorted, to match the class order confusion_matrix uses
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)


My question is: how can I display ALL labels in a matshow plot? I tried something like fontdict, but it still does not work.

Vu Anh asked Jan 14 '16 03:01


2 Answers

You can control the frequency of ticks using the matplotlib.ticker module.

In this case, you want a tick at every multiple of 1, so we can use a MultipleLocator.

Add these two lines before you call plt.show() (they need import matplotlib.ticker as ticker at the top of the script):

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

This produces a tick and a label for each letter in your y_true and y_pred.
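
To see what the locator changes, here is a minimal sketch (not part of the original answer) that compares the x tick positions matshow picks by default with the ones MultipleLocator(1) produces. The 14x14 identity matrix and the headless Agg backend are assumptions made only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display is needed
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

n_classes = 14  # same number of classes as the second example in the question

fig, ax = plt.subplots()
ax.matshow(np.eye(n_classes))  # stand-in data, not a real confusion matrix

# Tick positions inside the matrix as chosen by matshow's default locator.
default_ticks = [int(t) for t in ax.get_xticks() if 0 <= t < n_classes]

# Force one major tick at every integer position.
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
forced_ticks = [int(t) for t in ax.get_xticks() if 0 <= t < n_classes]

print("default:", default_ticks)
print("forced: ", forced_ticks)
plt.close(fig)
```

On the 14-class example the default locator typically returns fewer than 14 positions, which is why some labels went missing, while the forced list covers every index from 0 to 13.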

I also changed your matshow call to make use of the colormap you specify in the function call:

cax = ax.matshow(cm, cmap=cmap)


For completeness, your entire function will look like this:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import matplotlib.ticker as ticker

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # I also added cmap=cmap here, to make use of the
    # colormap you specify in the function call
    cax = ax.matshow(cm, cmap=cmap)
    plt.title(title)
    fig.colorbar(cax)
    if labels:
        ax.set_xticklabels([''] + labels)
        ax.set_yticklabels([''] + labels)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('confusionmatrix.png')
tmdavison answered Oct 26 '22 16:10


You can use the xticks and yticks functions to set both the tick positions and their labels. Your function will look like this (modifying the function from the answer above):

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # I also added cmap=cmap here, to make use of the
    # colormap you specify in the function call
    cax = ax.matshow(cm, cmap=cmap)
    plt.title(title)
    fig.colorbar(cax)
    if labels:
        plt.xticks(range(len(labels)), labels)
        plt.yticks(range(len(labels)), labels)

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('confusionmatrix.png')
arman answered Oct 26 '22 16:10
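
As a follow-up to both answers, here is a hedged sketch of the same idea written against the Axes object: pin one tick to every row and column index, then name the ticks in that order. It is not either answerer's code. The count_matrix helper is hypothetical and only stands in for confusion_matrix(y_true, y_pred, labels=labels) so the sketch needs nothing beyond NumPy and matplotlib, the Agg backend is an assumption for headless runs, and the labels are sorted so that label i always names row and column i.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np


def count_matrix(y_true, y_pred, labels):
    """Hypothetical stand-in for sklearn's confusion_matrix with explicit labels."""
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1  # rows are true classes, columns are predictions
    return cm


def plot_confusion_matrix(cm, labels, title="Confusion matrix", cmap=plt.cm.Blues):
    fig, ax = plt.subplots()
    cax = ax.matshow(cm, cmap=cmap)
    fig.colorbar(cax)
    ax.set_title(title)
    # One tick per row/column index, labelled in the same order as `labels`.
    positions = np.arange(len(labels))
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    return fig, ax


y_true = list("abcdefghijklmn")
y_pred = list("abcdefghijklmn")
labels = sorted(set(y_true))  # a plain set has no guaranteed order, so sort it
cm = count_matrix(y_true, y_pred, labels)
fig, ax = plot_confusion_matrix(cm, labels)
fig.savefig("confusionmatrix.png")
```

Because the tick positions are fixed explicitly, there is no need for the extra empty string in front of the label list, and the number of labels always equals the number of ticks.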