
How to plot a scikit-learn classification report?

Is it possible to plot a scikit-learn classification report with matplotlib? Let's assume I print the classification report like this:

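    from sklearn.metrics import classification_report, confusion_matrix

    # y_test and predictions are the true and predicted labels
    print('\n*Classification Report:\n', classification_report(y_test, predictions))
    confusion_matrix_graph = confusion_matrix(y_test, predictions)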

and I get:

    Classification Report:
                 precision    recall  f1-score   support

              1       0.62      1.00      0.76        66
              2       0.93      0.93      0.93        40
              3       0.59      0.97      0.73        67
              4       0.47      0.92      0.62       272
              5       1.00      0.16      0.28       413

    avg / total       0.77      0.57      0.49       858

How can I "plot" the above chart?

tumbleweed asked Jan 28 '15 19:01





2 Answers

Expanding on Bin's answer:

    import matplotlib.pyplot as plt
    import numpy as np


    def show_values(pc, fmt="%.2f", **kw):
        '''
        Heatmap with text in each cell with matplotlib's pyplot
        Source: https://stackoverflow.com/a/25074150/395857
        By HYRY
        '''
        pc.update_scalarmappable()
        ax = pc.axes  # use pc.get_axes() on very old matplotlib versions
        # zip is the Python 3 equivalent of itertools.izip;
        # ravel() keeps the loop working if get_array() returns a 2-D array
        values = pc.get_array().ravel()
        for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
            x, y = p.vertices[:-2, :].mean(0)
            # black text on light cells, white text on dark cells
            if np.all(color[:3] > 0.5):
                color = (0.0, 0.0, 0.0)
            else:
                color = (1.0, 1.0, 1.0)
            ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)


    def cm2inch(*tupl):
        '''
        Specify figure size in centimeter in matplotlib
        Source: https://stackoverflow.com/a/22787457/395857
        By gns-ank
        '''
        inch = 2.54
        if isinstance(tupl[0], tuple):
            return tuple(i / inch for i in tupl[0])
        return tuple(i / inch for i in tupl)


    def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels,
                figure_width=40, figure_height=20, correct_orientation=False, cmap='RdBu'):
        '''
        Inspired by:
        - https://stackoverflow.com/a/16124677/395857
        - https://stackoverflow.com/a/25074150/395857
        '''
        # Plot it out
        fig, ax = plt.subplots()
        c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2, cmap=cmap)

        # put the major ticks at the middle of each cell
        ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
        ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)

        # set tick labels
        ax.set_xticklabels(xticklabels, minor=False)
        ax.set_yticklabels(yticklabels, minor=False)

        # set title and x/y labels
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Remove last blank column
        ax.set_xlim(0, AUC.shape[1])

        # Turn off all the tick marks (the tick labels stay)
        ax.tick_params(axis='both', which='major',
                       bottom=False, top=False, left=False, right=False)

        # Add color bar
        fig.colorbar(c, ax=ax)

        # Add text in each cell
        show_values(c)

        # Proper orientation (origin at the top left instead of bottom left)
        if correct_orientation:
            ax.invert_yaxis()
            ax.xaxis.tick_top()

        # resize
        fig.set_size_inches(cm2inch(figure_width, figure_height))


    def plot_classification_report(classification_report, title='Classification report', cmap='RdBu'):
        '''
        Plot scikit-learn classification report.
        Extension based on https://stackoverflow.com/a/31689645/395857
        '''
        # Assumes the older text layout: header, blank line, one row per class,
        # blank line, then a single "avg / total" row.
        lines = classification_report.split('\n')

        plotMat = []
        support = []
        class_names = []
        for line in lines[2:(len(lines) - 2)]:
            t = line.strip().split()
            if len(t) < 2:
                continue
            class_names.append(t[0])
            plotMat.append([float(x) for x in t[1:len(t) - 1]])
            support.append(int(t[-1]))

        print('plotMat: {0}'.format(plotMat))
        print('support: {0}'.format(support))

        xlabel = 'Metrics'
        ylabel = 'Classes'
        xticklabels = ['Precision', 'Recall', 'F1-score']
        yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup in enumerate(support)]
        figure_width = 25
        figure_height = len(class_names) + 7
        correct_orientation = False
        heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels,
                figure_width, figure_height, correct_orientation, cmap=cmap)


    def main():
        sampleClassificationReport = """             precision    recall  f1-score   support

              Acacia       0.62      1.00      0.76        66
              Blossom       0.93      0.93      0.93        40
              Camellia       0.59      0.97      0.73        67
              Daisy       0.47      0.92      0.62       272
              Echium       1.00      0.16      0.28       413

            avg / total       0.77      0.57      0.49       858"""

        plot_classification_report(sampleClassificationReport)
        plt.savefig('test_plot_classif_report.png', dpi=200, format='png', bbox_inches='tight')
        plt.close()


    if __name__ == "__main__":
        main()
        #cProfile.run('main()') # if you want to do some profiling

outputs:

[Image: heatmap of precision, recall and F1-score per class]

Example with more classes (~40):

[Image: the same kind of heatmap for about 40 classes]

Franck Dernoncourt answered Sep 20 '22 03:09



No string processing + sns.heatmap

The following solution passes output_dict=True to classification_report (available since scikit-learn 0.20) to get the report as a dictionary instead of a string. A pandas DataFrame is built from that dictionary and drawn as a heat map with seaborn.


    import numpy as np
    import seaborn as sns
    from sklearn.metrics import classification_report
    import pandas as pd

Generating data. Classes: A,B,C,D,E,F,G,H,I,J

    # randint(0, 10) draws labels 0..9, so ten class names are needed
    true = np.random.randint(0, 10, size=100)
    pred = np.random.randint(0, 10, size=100)
    labels = np.arange(10)
    target_names = list("ABCDEFGHIJ")

Call classification_report with output_dict=True

    clf_report = classification_report(true,
                                       pred,
                                       labels=labels,
                                       target_names=target_names,
                                       output_dict=True)
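For orientation, here is a sketch of what the returned dictionary looks like and how the per-class entries can be separated from the summary entries before plotting. The numbers are hand-written placeholders rather than real results, and `split_report` and `SUMMARY_KEYS` are hypothetical names introduced only for this illustration. Which summary keys appear (for example `accuracy` versus `micro avg`) depends on the scikit-learn version and the arguments used.

```python
# Hand-written stand-in for classification_report(..., output_dict=True).
# All values are illustrative placeholders, not real results.
report = {
    "A": {"precision": 0.62, "recall": 1.00, "f1-score": 0.76, "support": 66},
    "B": {"precision": 0.93, "recall": 0.93, "f1-score": 0.93, "support": 40},
    "accuracy": 0.97,
    "macro avg": {"precision": 0.78, "recall": 0.97, "f1-score": 0.85, "support": 106},
    "weighted avg": {"precision": 0.74, "recall": 0.97, "f1-score": 0.82, "support": 106},
}

# Keys that hold summaries instead of a single class (assumed list).
SUMMARY_KEYS = {"accuracy", "micro avg", "macro avg", "weighted avg", "samples avg"}


def split_report(report):
    """Return (per_class, summary) dictionaries from a report dictionary."""
    per_class = {k: v for k, v in report.items() if k not in SUMMARY_KEYS}
    summary = {k: v for k, v in report.items() if k in SUMMARY_KEYS}
    return per_class, summary


per_class, summary = split_report(report)
for name, metrics in per_class.items():
    print(name, metrics["precision"], metrics["recall"], metrics["f1-score"])
```

Note that `accuracy` is a single float while the other entries are dictionaries, which matters when the whole thing is turned into a table.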

Create a dataframe from the dictionary and plot a heatmap of it.

    # .iloc[:-1, :] to exclude support
    sns.heatmap(pd.DataFrame(clf_report).iloc[:-1, :].T, annot=True)

[Image: seaborn heatmap of the classification report]
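If seaborn is not available, roughly the same picture can be drawn with matplotlib alone, which is what the question asked for. The following is a minimal sketch, not a definitive implementation: `plot_report_dict` and `SUMMARY_KEYS` are hypothetical names, and the `report` dictionary is a hand-written stand-in (placeholder values) for what `classification_report(..., output_dict=True)` would return.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

METRICS = ("precision", "recall", "f1-score")
# Keys that hold summaries instead of a single class (assumed list).
SUMMARY_KEYS = {"accuracy", "micro avg", "macro avg", "weighted avg", "samples avg"}


def plot_report_dict(report, ax=None, cmap="RdBu"):
    """Draw per-class precision/recall/F1 from a report dict as a heatmap."""
    classes = [k for k in report if k not in SUMMARY_KEYS]
    values = np.array([[report[c][m] for m in METRICS] for c in classes])

    if ax is None:
        _, ax = plt.subplots()
    im = ax.imshow(values, cmap=cmap, vmin=0.0, vmax=1.0, aspect="auto")

    # One tick per cell; show the support next to each class name.
    ax.set_xticks(np.arange(len(METRICS)))
    ax.set_xticklabels(METRICS)
    ax.set_yticks(np.arange(len(classes)))
    ax.set_yticklabels(["{} ({})".format(c, report[c]["support"]) for c in classes])

    # Write the numeric value into each cell.
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            ax.text(j, i, "{:.2f}".format(values[i, j]), ha="center", va="center")

    ax.figure.colorbar(im, ax=ax)
    return ax


# Placeholder data in the shape of an output_dict report (not real results).
report = {
    "A": {"precision": 0.62, "recall": 1.00, "f1-score": 0.76, "support": 66},
    "B": {"precision": 0.93, "recall": 0.93, "f1-score": 0.93, "support": 40},
    "accuracy": 0.97,
    "weighted avg": {"precision": 0.74, "recall": 0.97, "f1-score": 0.82, "support": 106},
}
ax = plot_report_dict(report)
```

The figure can then be written to disk with `ax.figure.savefig(...)`, as in the first answer. Fixing `vmin` and `vmax` to 0 and 1 keeps the colors comparable between reports.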

akilat90 answered Sep 19 '22 03:09
