How to plot precision and recall of multiclass classifier?

I'm using scikit-learn and I want to plot precision-recall curves. The classifier I'm using is RandomForestClassifier. All the examples in the scikit-learn documentation use binary classification. Also, can I plot a ROC curve for a multiclass problem?

The only multilabel example I found uses an SVM, which has a decision_function; RandomForestClassifier doesn't have one.

John Sall asked May 11 '19 12:05



1 Answer

From the scikit-learn documentation:

  • Precision-Recall:

Precision-recall curves are typically used in binary classification to study the output of a classifier. In order to extend the precision-recall curve and average precision to multi-class or multi-label classification, it is necessary to binarize the output. One curve can be drawn per label, but one can also draw a precision-recall curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).

  • Receiver Operating Characteristic (ROC):

ROC curves are typically used in binary classification to study the output of a classifier. In order to extend ROC curve and ROC area to multi-class or multi-label classification, it is necessary to binarize the output. One ROC curve can be drawn per label, but one can also draw a ROC curve by considering each element of the label indicator matrix as a binary prediction (micro-averaging).

Therefore, you should binarize the output and plot precision-recall and ROC curves for each class. Since RandomForestClassifier has no decision_function, use predict_proba to get the class probabilities and pass them as scores.

I divide the code into three parts:

  1. general settings, learning and prediction
  2. precision-recall curve
  3. ROC curve

1. general settings, learning and prediction

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

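import numpy as np
import matplotlib.pyplot as plt
#%matplotlib inline

mnist = fetch_openml("mnist_784")
# OpenML stores the digit labels as strings, so convert them to integers
y = mnist.target.astype(np.uint8)
n_classes = len(set(y))

# binarize the integer labels (not the raw string targets) into an indicator matrix
Y = label_binarize(y, classes=[*range(n_classes)])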

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                                                 max_depth=3,
                                                 random_state=0))
clf.fit(X_train, y_train)

y_score = clf.predict_proba(X_test)
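
As a side note, RandomForestClassifier handles multiclass targets natively, so the OneVsRestClassifier wrapper is optional. A minimal sketch of that variant, assuming a small synthetic dataset from make_classification instead of MNIST (the dataset and the names rf and Y_test are my own, for illustration only): fit on the integer labels, take predict_proba as the scores, and binarize only the test labels for the curve functions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

# Synthetic 3-class data (an assumption for illustration, not the MNIST setup)
X, y = make_classification(n_samples=600, n_features=20, n_informative=6,
                           n_classes=3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit directly on the integer labels, without a one-vs-rest wrapper
rf = RandomForestClassifier(n_estimators=50, random_state=0)
rf.fit(X_train, y_train)

# One probability column per class, in the order given by rf.classes_
y_score = rf.predict_proba(X_test)

# Binarize only the true test labels so they line up with the columns of y_score
Y_test = label_binarize(y_test, classes=rf.classes_)
```

Y_test[:, i] and y_score[:, i] can then be passed to precision_recall_curve and roc_curve exactly as in parts 2 and 3.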

2. precision-recall curve

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
    
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()

[Figure: precision vs. recall curve, one line per class]
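
The documentation quote also mentions micro-averaging, which the per-class loop does not cover. A rough sketch of it, assuming a binarized label matrix and a score matrix of the same shape: flatten both and treat every element as one binary prediction. The helper name micro_average_pr and the toy arrays are hypothetical, not part of scikit-learn or of the MNIST example.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score


def micro_average_pr(Y_true, y_score):
    """Micro-averaged precision-recall curve.

    Y_true:  (n_samples, n_classes) 0/1 label indicator matrix
    y_score: (n_samples, n_classes) scores, e.g. from predict_proba
    """
    # Every element of the indicator matrix counts as one binary prediction
    precision, recall, _ = precision_recall_curve(Y_true.ravel(), y_score.ravel())
    ap = average_precision_score(Y_true, y_score, average="micro")
    return precision, recall, ap


# Toy data for illustration only (4 samples, 3 classes)
Y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [1, 0, 0]])
y_score = np.array([[0.7, 0.2, 0.1],
                    [0.3, 0.5, 0.2],
                    [0.2, 0.2, 0.6],
                    [0.4, 0.5, 0.1]])

precision, recall, ap = micro_average_pr(Y_true, y_score)
```

With the variables from part 1, calling micro_average_pr(y_test, y_score) and then plt.plot(recall, precision) should add a single summary curve next to the per-class ones.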

3. ROC curve

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i])
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()

[Figure: ROC curve, one line per class]
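
If you also want a number per curve, the area under each ROC curve can be computed from the same fpr/tpr arrays with sklearn.metrics.auc. A small sketch on hand-made toy scores (the arrays and the roc_auc dictionary are assumptions for illustration, not results from the MNIST model):

```python
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# Toy labels and class-probability scores (6 samples, 3 classes)
y_true = np.array([0, 1, 2, 0, 1, 2])
y_score = np.array([[0.7, 0.2, 0.1],
                    [0.3, 0.5, 0.2],
                    [0.2, 0.2, 0.6],
                    [0.4, 0.5, 0.1],
                    [0.1, 0.8, 0.1],
                    [0.3, 0.3, 0.4]])
Y_true = label_binarize(y_true, classes=[0, 1, 2])

# Per-class area under the ROC curve (one-vs-rest)
roc_auc = {}
for i in range(Y_true.shape[1]):
    fpr_i, tpr_i, _ = roc_curve(Y_true[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr_i, tpr_i)

# Micro-average: flatten the indicator matrix and the scores
fpr_micro, tpr_micro, _ = roc_curve(Y_true.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr_micro, tpr_micro)
```

The values can go straight into the legend, for example label='class {} (AUC = {:.2f})'.format(i, roc_auc[i]).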

sentence answered Sep 20 '22 11:09