Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

ROC for multiclass classification

I'm doing different text classification experiments. Now I need to calculate the AUC-ROC for each task. For the binary classifications, I already made it work with this code:

# Binary text-classification ROC: encode labels, build the pipeline,
# get out-of-fold scores with 10-fold CV, then plot the ROC curve.
enc = LabelEncoder()
y = enc.fit_transform(labels)

feat_sel = SelectKBest(mutual_info_classif, k=200)
clf = linear_model.LogisticRegression()

pipe = Pipeline([('vectorizer', DictVectorizer()),
                 ('scaler', StandardScaler(with_mean=False)),
                 ('mutual_info', feat_sel),
                 ('logistregress', clf)])

# ROC needs a continuous score, not hard 0/1 predictions: with the default
# method='predict' the "curve" has only a single threshold. Use the predicted
# probability of the positive class (column 1, i.e. enc.classes_[1]).
y_score = model_selection.cross_val_predict(
    pipe, instances, y, cv=10, method='predict_proba')[:, 1]  # instances is a list of dictionaries

# visualisation ROC-AUC
fpr, tpr, thresholds = roc_curve(y, y_score)
# Store the area under a new name: rebinding `auc` would shadow the
# sklearn.metrics.auc function and break any later call to it.
roc_auc = auc(fpr, tpr)
print('auc =', roc_auc)

plt.figure()
plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
plt.legend(loc='lower right')
plt.plot([0, 1], [0, 1], 'r--')
plt.xlim([-0.1, 1.2])
plt.ylim([-0.1, 1.2])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

But now I need to do it for the multiclass classification task. I read somewhere that I need to binarize the labels, but I really don't get how to calculate ROC for multiclass classification. Tips?

like image 346
Bambi Avatar asked Jul 26 '17 16:07

Bambi


2 Answers

As people mentioned in the comments, you have to convert your problem into several binary problems using the one-vs-all (also called one-vs-rest) approach, so you end up with one ROC curve per class (n_classes curves in total).

A simple example:

from sklearn.metrics import roc_curve, auc
from sklearn import datasets
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Derive the class list from the data instead of hard-coding [0, 1, 2] and 3,
# so the snippet adapts to other multiclass datasets (needs >= 3 classes:
# with two classes label_binarize returns a single column).
classes = sorted(set(y))
n_classes = len(classes)
# One indicator column per class (one-vs-rest): column i is 1 where y == classes[i].
y = label_binarize(y, classes=classes)

# shuffle and split training and test sets
X_train, X_test, y_train, y_test = \
    train_test_split(X, y, test_size=0.33, random_state=0)

# classifier: one binary LinearSVC per class
clf = OneVsRestClassifier(LinearSVC(random_state=0))
y_score = clf.fit(X_train, y_train).decision_function(X_test)

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot the ROC curve of each class in its own figure; the title names the
# class so the figures can be told apart.
for i in range(n_classes):
    plt.figure()
    plt.plot(fpr[i], tpr[i], label='ROC curve (area = %0.2f)' % roc_auc[i])
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic for class %s'
              % iris.target_names[classes[i]])
    plt.legend(loc="lower right")
    plt.show()

(Images: the resulting ROC curve for each of the three classes, one figure per class.)

like image 147
omdv Avatar answered Sep 18 '22 15:09

omdv


This works for me and is convenient if you want all the curves on the same plot. It is similar to @omdv's answer, but perhaps a little more succinct.

def plot_multiclass_roc(clf, X_test, y_test, n_classes, figsize=(17, 6)):
    """Plot one-vs-rest ROC curves for every class of a fitted classifier on one axes.

    Parameters
    ----------
    clf : fitted classifier or Pipeline
        Must expose ``classes_`` and either ``decision_function`` or
        ``predict_proba``.
    X_test : array-like
        Features to score.
    y_test : array-like
        True labels for ``X_test``.
    n_classes : int
        Number of classes to plot; expected to equal ``len(clf.classes_)``.
    figsize : tuple, default (17, 6)
        Matplotlib figure size.
    """
    # Prefer decision_function; fall back to predict_proba for estimators
    # (e.g. tree ensembles) that do not provide it.
    if hasattr(clf, 'decision_function'):
        y_score = clf.decision_function(X_test)
    else:
        y_score = clf.predict_proba(X_test)

    # structures
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    # Calculate dummies once. Build them against clf.classes_ (the column
    # order of y_score) so that a class missing from y_test still gets its
    # own column. Plain pd.get_dummies(y_test) would drop that class and shift
    # every later column out of alignment with y_score.
    y_test_dummies = pd.get_dummies(
        pd.Categorical(y_test, categories=clf.classes_), dtype=int).values
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_dummies[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # roc for each class, all on the same axes
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver operating characteristic example')
    for i in range(n_classes):
        # Label with the real class value rather than the column index.
        ax.plot(fpr[i], tpr[i],
                label='ROC curve (area = %0.2f) for label %s' % (roc_auc[i], clf.classes_[i]))
    ax.legend(loc="best")
    ax.grid(alpha=.4)
    sns.despine()
    plt.show()


plot_multiclass_roc(full_pipeline, X_test, y_test, n_classes=16, figsize=(16, 10))
like image 31
pabz Avatar answered Sep 20 '22 15:09

pabz