Where does scikit-learn hold the decision labels of each leaf node in its tree structure?

I have trained a random forest model using scikit-learn and now I want to save its tree structures in a text file so I can use them elsewhere. According to this link, a tree object consists of a number of parallel arrays, each holding some information about the nodes of the tree (e.g. left child, right child, which feature it examines, ...). However, there seems to be no information about the class label corresponding to each leaf node! It isn't even mentioned in the examples provided in the link above.

Does anyone know where the class labels are stored in the scikit-learn decision tree structure?
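To make the parallel arrays concrete, here is a minimal sketch that walks `clf.tree_` and prints each node's split. It assumes a reasonably recent scikit-learn and uses the iris dataset only as a stand-in for the asker's data. None of these arrays holds a class label; leaves are recognisable only because both child arrays contain -1 for them.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Iris is just a stand-in dataset for illustration.
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
t = clf.tree_

# Every array below has one entry per node; the array index is the node id.
for node_id in range(t.node_count):
    left, right = t.children_left[node_id], t.children_right[node_id]
    if left == -1:
        # Leaves have no children: both child arrays hold -1 here.
        print(f"node {node_id}: leaf")
    else:
        print(f"node {node_id}: if X[:, {t.feature[node_id]}] <= "
              f"{t.threshold[node_id]:.2f} go to node {left}, else node {right}")
```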

whoAmI asked May 24 '17 12:05





1 Answer

Take a look at the docs for sklearn.tree.DecisionTreeClassifier.tree_.value:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()

clf.fit(iris.data, iris.target)

print(clf.classes_)

[0 1 2]

print(clf.tree_.value)

[[[ 50.  50.  50.]]

 [[ 50.   0.   0.]]

 [[  0.  50.  50.]]

 [[  0.  49.   5.]]

 [[  0.  47.   1.]]

 [[  0.  47.   0.]]

 [[  0.   0.   1.]]

 [[  0.   2.   4.]]

 [[  0.   0.   3.]]

 [[  0.   2.   1.]]

 [[  0.   2.   0.]]

 [[  0.   0.   1.]]

 [[  0.   1.  45.]]

 [[  0.   1.   2.]]

 [[  0.   0.   2.]]

 [[  0.   1.   0.]]

 [[  0.   0.  43.]]]

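Row i of clf.tree_.value "contains the constant prediction value of each node" (see help(clf.tree_)) for node i, and its columns correspond index-for-index to clf.classes_. For a leaf node (one whose clf.tree_.children_left entry is -1), the predicted label is therefore clf.classes_[argmax(clf.tree_.value[i][0])]. The entries above are per-class sample counts; newer scikit-learn versions store per-class fractions instead, which gives the same argmax.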
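Putting that together, here is a minimal sketch (again on iris as a stand-in dataset) that maps every leaf to its class label and cross-checks the result: clf.apply() routes each sample to its leaf, and looking up that leaf's label should reproduce clf.predict(). Taking the argmax keeps it independent of whether tree_.value holds counts or fractions.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
tree = clf.tree_

# Leaves are the nodes without children (children_left == -1).
is_leaf = tree.children_left == -1

# tree.value has shape (node_count, n_outputs, n_classes); use output 0.
# The argmax over the class axis is an index into clf.classes_.
leaf_labels = {
    int(node_id): clf.classes_[np.argmax(tree.value[node_id, 0])]
    for node_id in np.flatnonzero(is_leaf)
}
for node_id, label in leaf_labels.items():
    print(f"leaf {node_id} -> class {label}")

# Sanity check: send every sample to its leaf and compare with predict().
leaf_of_sample = clf.apply(X)
manual = np.array([leaf_labels[int(n)] for n in leaf_of_sample])
print("matches predict():", np.array_equal(manual, clf.predict(X)))
```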

See this answer for (barely) more details.
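For the original goal of saving a random forest's trees to a text file, one possible approach is sketched below. It is not an official scikit-learn export format: the JSON layout and the helper names tree_to_dict and leaf_fractions are hypothetical. It dumps each estimator's parallel arrays plus per-node class fractions and labels, then reloads the text and walks the trees to check the result against the forest.

```python
import json

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)


def tree_to_dict(estimator, classes):
    """Hypothetical helper: copy one fitted tree's arrays into plain lists."""
    t = estimator.tree_
    values = t.value[:, 0, :]  # output 0; columns follow forest.classes_
    totals = values.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # guard against division by zero
    fractions = values / totals  # per-node class fractions
    return {
        "children_left": t.children_left.tolist(),
        "children_right": t.children_right.tolist(),
        "feature": t.feature.tolist(),
        "threshold": t.threshold.tolist(),
        "class_fraction": fractions.tolist(),
        "label": classes[np.argmax(values, axis=1)].tolist(),
    }


model = {
    "classes": forest.classes_.tolist(),
    "trees": [tree_to_dict(est, forest.classes_) for est in forest.estimators_],
}
text = json.dumps(model)  # write this string to a file, e.g. forest.json


def leaf_fractions(tree, x):
    """Hypothetical helper: walk a serialized tree from the root to a leaf."""
    node = 0
    while tree["children_left"][node] != -1:
        if float(x[tree["feature"][node]]) <= tree["threshold"][node]:
            node = tree["children_left"][node]
        else:
            node = tree["children_right"][node]
    return tree["class_fraction"][node]


# Reload from text and average the trees' class fractions, as the forest does.
loaded = json.loads(text)
X32 = X.astype(np.float32)  # scikit-learn compares float32 features to thresholds
proba = np.array([
    np.mean([leaf_fractions(tree, x) for tree in loaded["trees"]], axis=0)
    for x in X32
])
pred = np.array(loaded["classes"])[np.argmax(proba, axis=1)]
print("same as forest.predict():", np.array_equal(pred, forest.predict(X)))
```

Storing the class fractions, not just the per-leaf label, matters for a forest: RandomForestClassifier combines its trees by averaging their predicted probabilities rather than by a majority vote of their labels, so the hard labels alone may not reproduce its predictions.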

blacksite answered Nov 11 '22 23:11
