
Scikit-Learn Decision Tree: Probability of prediction being a or b?

I have a basic decision tree classifier with Scikit-Learn:

#Used to determine men from women based on height and shoe size

from sklearn import tree

#height and shoe size
X = [[65,9],[67,7],[70,11],[62,6],[60,7],[72,13],[66,10],[67,7.5]]

Y=["male","female","male","female","female","male","male","female"]

#creating a decision tree
clf = tree.DecisionTreeClassifier()

#fitting the data to the tree
clf.fit(X, Y)

#predicting the gender from height and shoe size
#(predict expects a 2D array: one row per sample)
prediction = clf.predict([[68,9]])

#print the predicted gender
print(prediction)

When I run the program, it always outputs either "male" or "female", but how would I be able to see the probability of the prediction being male or female? For example, the prediction above returns "male", but how would I get it to print the probability of the prediction being male?

Thanks!

Davis Keene asked Nov 12 '17 17:11


3 Answers

You can do something like the following:

from sklearn import tree

#load data
X = [[65,9],[67,7],[70,11],[62,6],[60,7],[72,13],[66,10],[67,7.5]]
Y=["male","female","male","female","female","male","male","female"]

#build model
clf = tree.DecisionTreeClassifier()

#fit
clf.fit(X, Y)

#predict
prediction = clf.predict([[68,9],[66,9]])

#probabilities
probs = clf.predict_proba([[68,9],[66,9]])

#print the predicted gender
print(prediction)
print(probs)

Theory

clf.predict_proba(X) returns the predicted class probabilities for each sample in X. For a decision tree, the probability of a class is the fraction of training samples of that class in the leaf the sample falls into.
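
As a rough check of that definition, the sketch below recomputes the leaf fractions by hand and compares them with predict_proba. It reuses the data from this answer; the random_state argument and the variable names (query, labels_in_leaf, leaf_fractions) are additions for illustration only.

```python
import numpy as np
from sklearn import tree

# Height and shoe size, same data as in the answer
X = [[65, 9], [67, 7], [70, 11], [62, 6], [60, 7], [72, 13], [66, 10], [67, 7.5]]
Y = np.array(["male", "female", "male", "female", "female", "male", "male", "female"])

# random_state is an added assumption, only to make repeated runs identical
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, Y)

query = [[68, 9]]

# apply() returns the index of the leaf each sample ends up in
query_leaf = clf.apply(query)[0]
train_leaves = clf.apply(X)

# Labels of the training samples that share a leaf with the query
labels_in_leaf = Y[train_leaves == query_leaf]

# Fraction of each class in that leaf, in the order of clf.classes_
leaf_fractions = np.array([np.mean(labels_in_leaf == c) for c in clf.classes_])

print(labels_in_leaf)                 # training labels in the query's leaf
print(leaf_fractions)                 # hand-computed class fractions
print(clf.predict_proba(query)[0])    # should match the line above
```

With this data the leaf holding [68, 9] contains only "male" training samples, which is why the reported probability is exactly 1 for that class.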

Interpretation of the results:

The first print returns ['male' 'male'], so both samples in [[68,9],[66,9]] are predicted as male.

The second print returns:

[[ 0.  1.]
 [ 0.  1.]]

Each row corresponds to one sample and each column to one class. The 1 in the second column of both rows means that both samples were predicted as male with probability 1.

To see the order of the classes use: clf.classes_

This returns: ['female', 'male']
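
To avoid reading the columns by position, the probabilities can be paired with clf.classes_. This is a small sketch on the same data; random_state and the name labelled_rows are illustrative additions, not part of the original answer.

```python
from sklearn import tree

# Height and shoe size, same data as in the answer
X = [[65, 9], [67, 7], [70, 11], [62, 6], [60, 7], [72, 13], [66, 10], [67, 7.5]]
Y = ["male", "female", "male", "female", "female", "male", "male", "female"]

# random_state is an added assumption, only to make repeated runs identical
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, Y)

probs = clf.predict_proba([[68, 9], [66, 9]])

# The columns of probs follow the order of clf.classes_,
# so zip pairs each class label with its probability
labelled_rows = [
    {str(label): float(p) for label, p in zip(clf.classes_, row)}
    for row in probs
]

for row in labelled_rows:
    print(row)  # {'female': 0.0, 'male': 1.0}
```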

seralouk answered Oct 04 '22 08:10


Sounds like you need to read the sklearn documentation for DecisionTreeClassifier and see:

predict_proba(X[, check_input])

Coloane answered Oct 04 '22 07:10


The top answer is correct. You are getting binary output because your tree is fully grown rather than truncated, so every leaf is pure. To make the tree weaker, set max_depth to a lower value; the probabilities will then no longer look like [0. 1.] but more like [0.25 0.75]. Another problem is that the dataset is very small and easy to separate, so it is better to use a more complex dataset. Some links that might make this clearer:

https://rpmcruz.github.io/machine%20learning/2018/02/09/probabilities-trees.html
https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier.predict_proba
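
A minimal sketch of the max_depth effect follows. It swaps in scikit-learn's bundled iris dataset as the "more complex dataset", because in the height/shoe-size data a single split on shoe size already separates the classes, so lowering max_depth would not change the probabilities there. The choice of iris, the sample index and random_state are assumptions made for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Iris: 150 flowers, 3 classes (swapped in here as a larger example dataset)
iris = load_iris()
X, y = iris.data, iris.target

# Unrestricted tree versus a tree limited to a single split
full_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
shallow_tree = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)

# One training sample of the third class (virginica)
sample = X[100:101]

# The fully grown tree typically reports only 0s and 1s here
print(full_tree.predict_proba(sample))

# The depth-1 tree leaves the second and third class mixed in one leaf,
# so the probability is split between them
print(shallow_tree.predict_proba(sample))  # [[0.  0.5 0.5]]
```

The depth-1 tree can only separate the first class from the rest, so a sample from either remaining class lands in a leaf holding 50 samples of each and gets a probability of 0.5 for both.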

Saman Ariyanpour answered Oct 04 '22 08:10