Compute maximum F1 score using precision_recall_curve?

For a simple binary classification problem, I would like to find the threshold setting that maximizes the F1 score, which is the harmonic mean of precision and recall. Is there any built-in in scikit-learn that does this? Right now, I am simply calling

precision, recall, thresholds = precision_recall_curve(y_test, y_test_predicted_probas)

Then I can compute the F1 score from the precision and recall values at each index of those arrays:

curr_f1 = compute_f1(precision[index], recall[index])

Is there a better way of doing this, or is this how the library was intended to be used? Thanks.
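
For reference, a minimal sketch of that per-index approach; compute_f1 here is a stand-in for the helper mentioned above (one possible definition, not a scikit-learn function):

from sklearn.metrics import precision_recall_curve

def compute_f1(p, r):
    # Harmonic mean of precision and recall; defined as 0 when both are 0
    return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

precision, recall, thresholds = precision_recall_curve(y_test, y_test_predicted_probas)
f1_at_each_index = [compute_f1(p, r) for p, r in zip(precision, recall)]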

asked Jul 16 '19 by information_interchange

People also ask

What is the maximum value of the F1 score?

For a binary classification task, the higher the F1 score the better: 0 is the worst possible value and 1 is the best. Beyond this, most online sources don't give you much guidance on how to interpret a specific F1 score.

How is the F1 score calculated from precision and recall?

The F-Measure (F1 score) is the harmonic mean of precision and recall. For example, perfect precision and recall result in a perfect F-Measure:

F-Measure = (2 * Precision * Recall) / (Precision + Recall)
          = (2 * 1.0 * 1.0) / (1.0 + 1.0)
          = 2.0 / 2.0
          = 1.0
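
A quick way to check this in code, using scikit-learn's f1_score on a toy example where every prediction is correct (the labels here are made up for illustration):

from sklearn.metrics import f1_score

y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 1, 0, 1]  # perfect predictions: precision = recall = 1.0
print(f1_score(y_true, y_pred))  # 1.0, the maximum possible F1 score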


2 Answers

After calling precision_recall_curve, the precision, recall, and thresholds values are NumPy arrays, so you can use NumPy functions directly to find the threshold that maximizes the F1 score:

import numpy as np

# F1 is the harmonic mean of precision and recall, computed element-wise
f1_scores = 2*recall*precision/(recall+precision)
print('Best threshold: ', thresholds[np.argmax(f1_scores)])
print('Best F1-Score: ', np.max(f1_scores))
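
For context, here is a self-contained version of the same idea; the dataset and model below are placeholders to produce predicted probabilities, not part of the original answer:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

# Toy data and model, only to get predicted probabilities for the positive class
X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
probas = LogisticRegression(max_iter=1000).fit(X_train, y_train).predict_proba(X_test)[:, 1]

precision, recall, thresholds = precision_recall_curve(y_test, probas)
f1_scores = 2 * recall * precision / (recall + precision)
best = np.argmax(f1_scores)
print('Best threshold:', thresholds[best])
print('Best F1-Score:', f1_scores[best])

Note that this plain division can produce NaNs where precision and recall are both zero; the second answer below shows how to guard against that with np.divide.
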
answered Sep 23 '22 by Mike Alexander Doepking

Sometimes precision_recall_curve picks thresholds that are too high for the data, so you end up with points where both precision and recall are zero. This produces NaNs when computing the F1 score. To ensure correct output, use np.divide so the division only happens where the denominator is nonzero:

import numpy as np
from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(y_test, y_test_predicted_probas)
numerator = 2 * recall * precision
denom = recall + precision
# Divide only where the denominator is nonzero; elsewhere the F1 score defaults to 0
f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom!=0))
max_f1 = np.max(f1_scores)
max_f1_thresh = thresholds[np.argmax(f1_scores)]
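
One detail worth noting (not part of the original answer): precision and recall as returned by precision_recall_curve have one more element than thresholds, because the final (precision=1, recall=0) point has no corresponding threshold. Continuing from the snippet above, you can drop that last point so the F1 array lines up exactly with thresholds before taking the argmax:

# precision/recall have len(thresholds) + 1 elements; drop the final
# (precision=1, recall=0) point so indices line up with thresholds
f1_scores = f1_scores[:-1]
best_idx = np.argmax(f1_scores)
max_f1 = f1_scores[best_idx]
max_f1_thresh = thresholds[best_idx]
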
answered Sep 20 '22 by Craig Bidstrup