
How does predict_proba in sklearn produce two columns? What is their significance?

I was using simple logistic regression on a prediction problem and trying to plot the precision_recall_curve and the roc_curve with predict_proba(X_test). I checked the docstring of predict_proba, but it didn't have much detail on how it works. I kept getting a bad input error and found that the shapes of y_test and predict_proba(X_test) don't match. I finally discovered that predict_proba() produces 2 columns and that people use the second.

It would be really helpful if someone could explain how it produces two columns and what they signify. TIA.

Md. Rezaul Karim asked Apr 27 '19 17:04



2 Answers

predict_proba() produces output of shape (N, k), where N is the number of data points and k is the number of classes you're trying to classify. You have two classes, hence two columns. The columns follow the order of model.classes_, which is sorted. Say your labels are 0 (healthy) and 1 (diabetes): if a data point is predicted to have an 80% chance of being diabetic, and consequently a 20% chance of being healthy, then the output row for that point will be [0.2, 0.8] to reflect these probabilities. In general, you can get the probabilities for the j-th class in that order with model.predict_proba(X)[:, j-1].
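To make the (N, k) shape concrete, here is a minimal sketch on synthetic data. The dataset from make_classification and the variable names are assumptions for illustration, not the asker's setup.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic binary data (an assumption); the labels are 0 and 1.
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
model = LogisticRegression().fit(X, y)

proba = model.predict_proba(X)

print(proba.shape)   # one row per sample, one column per class
print(proba[:3])     # each row is [P(class 0), P(class 1)]

# The two columns are complementary: every row sums to 1.
rows_sum_to_one = np.allclose(proba.sum(axis=1), 1.0)

# predict() picks the class whose column holds the larger probability.
matches_predict = np.array_equal(
    model.classes_[proba.argmax(axis=1)], model.predict(X)
)
print(rows_sum_to_one, matches_predict)
```

Since the second column is just one minus the first in the binary case, the curve functions only need one of them, conventionally the column for the positive class.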

As for plotting, you can do the following for precision_recall_curve:

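from sklearn.metrics import precision_recall_curve
# Column 1 holds the probability of the positive class (label 1)
predicted = logisticReg.predict_proba(X_test)
precision, recall, thresholds = precision_recall_curve(y_test, predicted[:, 1])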

For ROC:

from sklearn.metrics import roc_curve
predicted = logisticReg.predict_proba(X_test)
fpr, tpr, thresholds = roc_curve(y_test, predicted[:, 1])

Note that this changes for multi-class or multi-label classification: there are more than two columns, and the curves are usually computed one class at a time. The sklearn docs have an example of that.
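For more than two classes, one common approach is one-vs-rest: binarize the true labels and compute a curve per column. The sketch below assumes that approach and uses the iris dataset purely for illustration. It is not necessarily the example from the sklearn docs.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
proba = model.predict_proba(X_test)  # three classes, so three columns

# One 0/1 indicator column per class, in the same order as model.classes_.
y_test_bin = label_binarize(y_test, classes=model.classes_)

# Treat each class in turn as "positive" and the rest as "negative".
curves = {}
for j, cls in enumerate(model.classes_):
    fpr, tpr, _ = roc_curve(y_test_bin[:, j], proba[:, j])
    precision, recall, _ = precision_recall_curve(y_test_bin[:, j], proba[:, j])
    curves[cls] = {"fpr": fpr, "tpr": tpr, "precision": precision, "recall": recall}
```

Each entry in curves can then be plotted separately, one line per class.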

Turtalicious answered Sep 18 '22 13:09



You can tell which column belongs to which class from the classifier's classes_ attribute. If the classifier is named model, then model.classes_ gives the distinct classes, in the same order as the columns of predict_proba.
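A small sketch of that lookup, using a made-up one-feature dataset with string labels (the data and variable names are assumptions for illustration):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Tiny made-up dataset: one feature, string labels.
X = np.array([[1.0], [2.0], [3.0], [7.0], [8.0], [9.0]])
y = np.array(["healthy", "healthy", "healthy", "diabetes", "diabetes", "diabetes"])

model = LogisticRegression().fit(X, y)

# classes_ is stored in sorted order, so "diabetes" comes before "healthy".
print(model.classes_)

# Look up the column by label instead of hard-coding an index.
col = list(model.classes_).index("diabetes")
p_diabetes = model.predict_proba(X)[:, col]
print(col, p_diabetes)
```

With string labels like these, column 1 is "healthy", so blindly taking [:, 1] would give the wrong class. Looking the index up in classes_ avoids that.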

Shibendu answered Sep 19 '22 13:09
