
Getting the maximum accuracy for a binary probabilistic classifier in scikit-learn

Is there any built-in function to get the maximum accuracy for a binary probabilistic classifier in scikit-learn?

E.g. to get the maximum F1-score I do:

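import sklearn.metrics

# Area under the precision-recall curve (AUPRC)
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_true, y_score)
auprc = sklearn.metrics.auc(recall, precision)

# Scan every threshold for the best F1-score.
# zip() stops at len(thresholds), skipping the final (precision=1, recall=0) point.
max_f1 = 0
max_f1_threshold = None
for r, p, t in zip(recall, precision, thresholds):
    if p + r == 0:
        continue
    f1 = (2 * p * r) / (p + r)
    if f1 > max_f1:
        max_f1 = f1
        max_f1_threshold = t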
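The same F1 search can be written without an explicit loop. A minimal NumPy sketch, assuming `precision`, `recall` and `thresholds` come straight from `precision_recall_curve` (so the first two arrays have one more entry than `thresholds`); `best_f1` is a hypothetical helper, not a scikit-learn function:

```python
import numpy as np


def best_f1(precision, recall, thresholds):
    """Return (max_f1, threshold) from precision_recall_curve output (hypothetical helper)."""
    # Drop the last (precision=1, recall=0) point, which has no threshold.
    precision = np.asarray(precision, dtype=float)[:-1]
    recall = np.asarray(recall, dtype=float)[:-1]
    thresholds = np.asarray(thresholds, dtype=float)

    # F1 = 2PR / (P + R), defined as 0 where P + R == 0.
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom,
                   out=np.zeros_like(denom), where=denom > 0)

    # argmax returns the first maximum, like the strict ">" in the loop.
    best = int(np.argmax(f1))
    return float(f1[best]), float(thresholds[best])
```

One difference from the loop: if every F1 value is 0, this returns the first threshold instead of `None`.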

I could compute the maximum accuracy in a similar fashion:

import numpy as np
import sklearn.metrics

# Evaluate accuracy on a fixed grid of thresholds: 0.0, 0.1, ..., 0.9
accuracies = []
thresholds = np.arange(0, 1, 0.1)
for threshold in thresholds:
    y_pred = np.greater(y_score, threshold).astype(int)  # 1 if score > threshold
    accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)
    accuracies.append(accuracy)

accuracies = np.array(accuracies)
max_accuracy = accuracies.max()
max_accuracy_threshold = thresholds[accuracies.argmax()]

but I wonder whether there is any built-in function.
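A fixed grid such as `np.arange(0, 1, 0.1)` can miss the best cut-off. One way to search every distinct score exactly is to sort the scores once and count true and false positives cumulatively. The sketch below assumes 0/1 labels and counts a sample as positive when its score is `>=` the threshold; `max_accuracy` is a hypothetical helper name:

```python
import numpy as np


def max_accuracy(y_true, y_score):
    """Return (best_accuracy, best_threshold), predicting 1 when score >= threshold.

    Hypothetical helper: tries every distinct score as a threshold, plus
    np.inf (predict all 0s), using one sort instead of one pass per threshold.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    n = y_true.size

    # Sort by descending score so equal scores sit next to each other.
    order = np.argsort(-y_score)
    s, t = y_score[order], y_true[order]

    # Index of the last sample in each group of tied scores.
    group_end = np.r_[np.flatnonzero(np.diff(s)), n - 1]

    # Predicting positive for s[:i + 1] gives these counts at each group end i.
    tp = np.cumsum(t)[group_end]
    fp = (group_end + 1) - tp
    n_pos = int(t.sum())
    n_neg = n - n_pos
    tn = n_neg - fp

    # Prepend the "predict everything negative" case (threshold = inf).
    acc = np.r_[n_neg / n, (tp + tn) / n]
    thresholds = np.r_[np.inf, s[group_end]]
    best = int(np.argmax(acc))
    return float(acc[best]), float(thresholds[best])
```

For example, `max_accuracy([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])` returns `(0.75, 0.8)`.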

asked Jul 18 '15 06:07 by Franck Dernoncourt




1 Answer

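import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve

# roc_curve takes its thresholds from the observed scores,
# so unlike a fixed grid they adapt to the data.
probs = np.asarray(probs)
fpr, tpr, thresholds = roc_curve(y_true, probs)
accuracy_scores = []
for thresh in thresholds:
    # roc_curve counts a sample as positive when its score >= threshold
    accuracy_scores.append(accuracy_score(y_true, (probs >= thresh).astype(int)))

accuracies = np.array(accuracy_scores)
max_accuracy = accuracies.max()
max_accuracy_threshold = thresholds[accuracies.argmax()]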
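Since accuracy = (TP + TN) / (P + N), with TP = tpr * P and TN = (1 - fpr) * N, the accuracy at every ROC point can also be computed directly from the `fpr` and `tpr` arrays, without calling `accuracy_score` once per threshold. A sketch, assuming `n_pos` and `n_neg` are the numbers of positive and negative labels in `y_true`; `accuracy_from_roc` is a hypothetical name:

```python
import numpy as np


def accuracy_from_roc(fpr, tpr, n_pos, n_neg):
    """Accuracy at each ROC point: (TP + TN) / (P + N). Hypothetical helper."""
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    tp = tpr * n_pos          # true positives at each threshold
    tn = (1.0 - fpr) * n_neg  # true negatives at each threshold
    return (tp + tn) / (n_pos + n_neg)
```

Assuming `y_true` is a NumPy array of 0/1 labels, `acc = accuracy_from_roc(fpr, tpr, np.sum(y_true == 1), np.sum(y_true == 0))` followed by `thresholds[acc.argmax()]` picks the most accurate of the thresholds returned by `roc_curve`.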

answered Oct 12 '22 09:10 by ahoosh