
Difference between predict vs predict_proba in scikit-learn

Suppose I have created a model, and my target variable is either 0, 1 or 2. It seems that if I use predict, the answer is one of 0, 1 or 2. But if I use predict_proba, I get a row with 3 columns for each row of the input, for example:

   model = ... Classifier       # It could be any classifier
   m1 = model.predict(mytest)
   m2 = model.predict_proba(mytest)

   # Now suppose m2[3] = [0.6, 0.2, 0.2]

Suppose I use both predict and predict_proba. If index 3 of the predict_proba result is the row above, then index 3 of the predict result should be 0. Is this the case? I am trying to understand how the outputs of predict and predict_proba on the same model relate to each other.

Fabiana asked Apr 13 '20 09:04


1 Answer

  • predict() returns the predicted class label for each sample (in your case one of 0, 1, or 2).
  • predict_proba() returns the predicted probability of every class for each sample, one column per class.

From the example output that you shared,

  • predict() would output class 0, since class 0 has the highest probability (0.6).
  • [0.6, 0.2, 0.2] is the output of predict_proba, which simply denotes that the probabilities for classes 0, 1, and 2 are 0.6, 0.2, and 0.2 respectively.
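
The relationship between the two outputs can be checked directly. The following is a minimal sketch, assuming scikit-learn and NumPy are installed; the iris dataset and RandomForestClassifier are stand-ins chosen for illustration, since the question does not name a dataset or classifier. For this estimator, predict returns the class with the highest predicted probability. Some estimators, for example SVC(probability=True), can occasionally disagree, because their probabilities come from a separate calibration step.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Iris has three classes labelled 0, 1 and 2, like the target in the question.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Illustrative choice of classifier; the question leaves it unspecified.
model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train)

labels = model.predict(X_test)        # shape (n_samples,)
probas = model.predict_proba(X_test)  # shape (n_samples, 3)

# Taking the most probable column of each row and mapping it through
# classes_ reproduces the output of predict() for this estimator.
labels_from_probas = model.classes_[np.argmax(probas, axis=1)]

print(probas[3], labels[3])
print(np.array_equal(labels, labels_from_probas))
```

Each row of probas sums to 1, and the label at a given index of labels corresponds to the largest entry in the row at the same index of probas.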

Now as the documentation mentions for predict_proba, the resulting array is ordered based on the labels you've been using:

The returned estimates for all classes are ordered by the label of classes.

Therefore, in your case where your class labels are [0, 1, 2], the columns of the predict_proba output follow that same order: 0.6 is the probability that the instance belongs to class 0, and the two 0.2 values are the probabilities that it belongs to classes 1 and 2 respectively.
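
When the labels are not already 0, 1, 2, the column order can be read from the fitted model's classes_ attribute, which holds the sorted unique labels. Below is a small sketch of that, assuming scikit-learn and NumPy are installed; the toy data, the string labels, and the choice of LogisticRegression are hypothetical and only serve to make the ordering visible.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy one-feature data with hypothetical string labels, deliberately
# listed in non-alphabetical order.
X = np.array([[0.0], [0.2], [1.0], [1.2], [2.0], [2.2]])
y = np.array(["dog", "dog", "cat", "cat", "bird", "bird"])

model = LogisticRegression().fit(X, y)

# classes_ holds the sorted unique labels; column i of predict_proba
# is the probability of classes_[i].
print(model.classes_)  # ['bird' 'cat' 'dog']

probas = model.predict_proba([[0.1]])[0]
for label, p in zip(model.classes_, probas):
    print(f"{label}: {p:.3f}")
```

Pairing classes_ with a row of predict_proba in this way avoids guessing which column belongs to which label.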


For a more comprehensive explanation, refer to the article What is the difference between predict() and predict_proba() in scikit-learn on TDS.

Giorgos Myrianthous answered Sep 28 '22 16:09