I am using a scikit-learn DecisionTreeClassifier on a 3-class dataset. After fitting the classifier, I access the leaf nodes through the tree_ attribute to get the number of instances of each class that end up in a given node.
from sklearn import tree

clf = tree.DecisionTreeClassifier(max_depth=5)
clf.fit(X, y)
# let's assume there is a leaf node with id 5
print(clf.tree_.value[5])
This will print out:
>>> array([[ 0., 1., 68.]])
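Rather than hard-coding a node id such as 5, all leaves can be enumerated from the tree's structure arrays. The sketch below is a minimal illustration under stated assumptions: the iris dataset stands in for the asker's 3-class X and y, and random_state=0 is added only for reproducibility. It relies on scikit-learn marking "no child" with -1 in tree_.children_left, and on tree_.value having shape (node_count, n_outputs, n_classes). Depending on the scikit-learn version, the entries may be raw per-class counts or normalized fractions.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris

# Assumption: iris stands in for the asker's 3-class X, y.
X, y = load_iris(return_X_y=True)
clf = tree.DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(X, y)

t = clf.tree_
# A node is a leaf when it has no left child; scikit-learn stores -1 there.
leaf_ids = np.where(t.children_left == -1)[0]

for node_id in leaf_ids:
    # value[node_id] has shape (n_outputs, n_classes); [0] selects the single output.
    print(node_id, t.value[node_id][0])
```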
But how do I know which position in that array belongs to which class? The classifier also has a classes_ attribute, which is an array as well:
>>> clf.classes_
array(['CLASS_1', 'CLASS_2', 'CLASS_3'], dtype=object)
Does index 1 of the value array correspond to the class at index 1 of the classes_ array, and so on?
I asked about this on the scikit-learn mailing list, and my guess was right: index i of the value array corresponds to the class at index i of clf.classes_.
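The mapping can be checked empirically. This is a sketch, not an official recipe. It assumes iris with relabeled string targets ("CLASS_1" to "CLASS_3") as a stand-in for the asker's data, and leaf_class_distribution is a hypothetical helper introduced only for illustration. The check uses clf.apply to find the leaf each training sample reaches, takes the class whose value entry is largest, and compares it with clf.predict. If the column order matched classes_, the two agree for every sample.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris

# Assumption: iris with string labels stands in for CLASS_1..CLASS_3.
X, y_int = load_iris(return_X_y=True)
names = np.array(["CLASS_1", "CLASS_2", "CLASS_3"], dtype=object)
y = names[y_int]

clf = tree.DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(X, y)

def leaf_class_distribution(clf, node_id):
    """Hypothetical helper: pair each column of tree_.value[node_id] with its label."""
    row = clf.tree_.value[node_id][0]
    return dict(zip(clf.classes_, row))

# For each sample, pick the label whose column is largest in its leaf's value row.
leaves = clf.apply(X)  # id of the leaf each sample ends up in
majority = clf.classes_[np.argmax(clf.tree_.value[leaves, 0, :], axis=1)]

# If column i of value corresponds to classes_[i], this matches predict().
assert np.array_equal(majority, clf.predict(X))

print(leaf_class_distribution(clf, leaves[0]))
```

Note that classes_ holds the sorted unique labels seen in y during fit, so the column order follows that sorted order rather than the order in which classes first appear in the data.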