
In scikit's precision_recall_curve, why does thresholds have a different dimension from recall and precision?

I want to see how precision and recall vary with the threshold (not just with each other).

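from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_curve

model = RandomForestClassifier(n_estimators=500, n_jobs=-1)
model.fit(X_train, y_train)
probas = model.predict_proba(X_test)[:, 1]  # probability of the positive class
precision, recall, thresholds = precision_recall_curve(y_test, probas)
print(len(precision))
print(len(thresholds))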

Returns:

283  
282

I therefore cannot plot them together. Any clues as to why this might be the case?

sapo_cosmico asked Jul 26 '15 15:07



1 Answer

For this problem, the last precision and recall values should be ignored. They are always 1. and 0. respectively and do not have a corresponding threshold, which is why precision and recall each have one more element than thresholds.

For example, here is one solution:

import matplotlib.pyplot as plt

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.figure(figsize=(8, 5))
    # Drop the final (1, 0) point, which has no matching threshold.
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.xlabel("Threshold")
    plt.legend()
    plt.show()

plot_precision_recall_vs_threshold(precision, recall, thresholds)

These values are there so that the plot starts at the y-axis (x=0) when you plot precision against recall.
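To see where the extra point comes from, here is a minimal pure-Python sketch of the idea. It is a simplified illustration and not scikit-learn's actual implementation: the helper name pr_curve_sketch and the toy data are made up for this example, and it assumes binary 0/1 labels with at least one positive.

```python
def pr_curve_sketch(y_true, scores):
    """Simplified illustration; NOT scikit-learn's implementation."""
    thresholds = sorted(set(scores))  # one threshold per distinct score
    total_pos = sum(y_true)           # assumes at least one positive label
    precisions, recalls = [], []
    for t in thresholds:
        predicted = [s >= t for s in scores]
        tp = sum(1 for p, y in zip(predicted, y_true) if p and y == 1)
        precisions.append(tp / sum(predicted))
        recalls.append(tp / total_pos)
    # Extra endpoint with no threshold: precision 1, recall 0.
    precisions.append(1.0)
    recalls.append(0.0)
    return precisions, recalls, thresholds


# Hypothetical toy data for illustration only.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
precisions, recalls, thresholds = pr_curve_sketch(y_true, scores)

print(len(precisions), len(thresholds))
# zip stops at the shortest sequence, so the unmatched endpoint is skipped.
for t, p, r in zip(thresholds, precisions, recalls):
    print(f"threshold={t:.2f} precision={p:.2f} recall={r:.2f}")
```

Each threshold yields one (precision, recall) pair, and the appended endpoint is the one element left over, so slicing with [:-1] or pairing the arrays with zip lines everything up.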

Suvam answered Sep 22 '22 23:09
