 

Plotting a ROC curve in scikit yields only 3 points


TLDR: scikit's roc_curve function is only returning 3 points for a certain dataset. Why could this be, and how do we control how many points to get back?

I'm trying to draw a ROC curve, but consistently get a "ROC triangle".

    lr = LogisticRegression(multi_class = 'multinomial', solver = 'newton-cg')
    y = data['target'].values
    X = data[['feature']].values

    model = lr.fit(X,y)

    # get probabilities for clf
    probas_ = model.predict_log_proba(X)

Just to make sure the lengths are ok:

    print len(y)
    print len(probas_[:, 1])

Returns 13759 on both.

Then running:

    false_pos_rate, true_pos_rate, thresholds = roc_curve(y, probas_[:, 1])
    print false_pos_rate

returns [ 0. 0.28240129 1. ]

If I print thresholds, I get array([ 0.4822225 , -0.5177775 , -0.84595197]) (always only 3 points).

It is therefore no surprise that my ROC curve looks like a triangle.

What I cannot understand is why scikit's roc_curve is only returning 3 points. Help hugely appreciated.


asked May 05 '15 11:05 by sapo_cosmico




1 Answer

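The number of points depends on the number of unique values in the scores you pass to roc_curve. Here probas_[:, 1] contains only 2 unique values, -0.5177775 and -0.84595197 (the first threshold, 0.4822225, is just the largest score plus 1, which roc_curve adds so that the curve starts at (0, 0)), so the function is giving the correct output. With a single feature, that most likely means the feature itself only takes 2 distinct values.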
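To see the effect in isolation, here is a minimal sketch on made-up data (the names x_binary, x_continuous and roc_points are hypothetical and not from the question). It fits the same kind of single-feature model twice, once on a feature with two distinct values and once on a continuous feature, and compares the number of unique scores with the number of ROC points. It also passes drop_intermediate=False, a roc_curve parameter in recent scikit-learn versions that keeps every threshold instead of pruning redundant ones.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

# Synthetic data, assumed purely for illustration.
rng = np.random.RandomState(0)
n = 1000
x_binary = rng.randint(0, 2, size=n)      # feature with only 2 distinct values
x_continuous = rng.normal(size=n)         # feature with many distinct values
y = (x_binary + x_continuous + rng.normal(size=n) > 0.5).astype(int)


def roc_points(feature):
    """Fit a one-feature model and count unique scores and ROC points."""
    X = feature.reshape(-1, 1).astype(float)
    scores = LogisticRegression().fit(X, y).predict_proba(X)[:, 1]
    fpr, tpr, thresholds = roc_curve(y, scores)
    # Keep every threshold, including ones that add no new corner.
    fpr_full, _, _ = roc_curve(y, scores, drop_intermediate=False)
    return len(np.unique(scores)), len(fpr), len(fpr_full)


n_unique_binary, n_points_binary, n_full_binary = roc_points(x_binary)
n_unique_cont, n_points_cont, n_full_cont = roc_points(x_continuous)

print("binary feature:    ", n_unique_binary, "unique scores,",
      n_points_binary, "ROC points")
print("continuous feature:", n_unique_cont, "unique scores,",
      n_points_cont, "ROC points,", n_full_cont, "without pruning")
```

So there is no argument that asks roc_curve for more points: the only way to get a finer curve is to give it scores with more distinct values, for example by training on more (or less coarse) features.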

answered Sep 20 '22 17:09 by pyan