I have trained a decision tree using a dataset. Now I want to see which samples fall under which leaf of the tree.
From the tree shown here, I want the red-circled samples.
I am using scikit-learn's decision tree implementation in Python.
A decision tree is a flowchart-like structure in which each internal node represents a "test" on an attribute (e.g. whether a coin flip comes up heads or tails), each branch represents the outcome of the test, and each leaf node represents a class label (decision taken after computing all attributes).
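To see that flowchart structure for a fitted scikit-learn tree, sklearn.tree.export_text prints one line per test at an internal node and one "class:" line per leaf. This is a minimal sketch on the iris data; max_depth=2 is an arbitrary choice (an assumption, not part of the question) that keeps the printout short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
# max_depth=2 is only chosen to keep the printed tree small
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)

# Lines with "<=" / ">" are the tests at internal nodes (the branches);
# lines with "class:" are leaves and the class label they predict.
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```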
For each leaf k, compute its Gini impurity and weight it by the number of observations in the k-th leaf divided by the total number of observations in all leaf nodes. After looping over all leaves, sum the weighted values to obtain the final Gini impurity.
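A minimal sketch of that weighted sum, assuming the Gini impurity of a leaf is 1 - sum_c p_c^2 over the class proportions p_c of the samples in it. The helper names gini and weighted_leaf_gini are hypothetical; clf.apply is used to find which leaf each sample lands in.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def gini(labels):
    """Gini impurity 1 - sum_c p_c**2 of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def weighted_leaf_gini(clf, X, y):
    """Sum over leaves of (share of samples in the leaf) * (Gini of that leaf)."""
    leaf_ids = clf.apply(X)              # leaf index for every sample
    n_total = len(y)
    total = 0.0
    for leaf in np.unique(leaf_ids):     # loop over the leaves that received samples
        in_leaf = leaf_ids == leaf
        weight = in_leaf.sum() / n_total     # observations in leaf k / all observations
        total += weight * gini(y[in_leaf])
    return total


iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)
print(weighted_leaf_gini(clf, iris.data, iris.target))
```

On the training data this should agree with the per-node impurities scikit-learn stores in clf.tree_.impurity, weighted by clf.tree_.n_node_samples.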
The maximum depth of a decision tree is the length of the longest path from the root to a leaf. A binary tree of maximum depth k can have at most 2^k leaves.
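scikit-learn trees are binary, so the 2^k bound can be checked directly on a fitted model. A small sketch, assuming a fully grown tree on iris; get_depth and get_n_leaves are methods of DecisionTreeClassifier.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

depth = clf.get_depth()        # length of the longest root-to-leaf path
n_leaves = clf.get_n_leaves()  # number of terminal nodes
# every split produces exactly two children, so depth k allows at most 2**k leaves
print(depth, n_leaves, 2 ** depth)
```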
Splitting is the process of dividing a single node into two or more sub-nodes. A leaf node, also known as a terminal node, is a node that is not split further. A branch, sometimes called a sub-tree, is a section of a decision tree. Pruning, which removes the sub-nodes of a node, is the opposite of splitting.
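In a fitted scikit-learn tree, a leaf is a node with no children: the low-level clf.tree_ object stores -1 in children_left and children_right for such nodes. This sketch (iris data assumed) separates leaf ids from internal node ids, which is useful for knowing which node numbers can show up in the output of apply.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)
tree = clf.tree_

# A node that is not split further has no children; -1 marks a missing child.
is_leaf = tree.children_left == -1
leaf_nodes = np.flatnonzero(is_leaf)
internal_nodes = np.flatnonzero(~is_leaf)

print("leaves:", leaf_nodes)
print("internal (split) nodes:", internal_nodes)
```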
If you only want the leaf for each sample, you can simply use apply:
clf.apply(iris.data)
array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 14, 5, 5, 5, 5, 5, 5, 10, 5, 5, 5, 5, 5, 10, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 16, 16, 16, 16, 16, 16, 6, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 8, 16, 16, 16, 16, 16, 16, 15, 16, 16, 11, 16, 16, 16, 8, 8, 16, 16, 16, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16])
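Since apply already returns the leaf id of each sample, grouping sample indices by leaf needs only one pass over that array. A sketch under the same setup as the complete code below; some_leaf is a hypothetical stand-in for whichever leaf you circled, taken here from the first sample.

```python
from collections import defaultdict

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

leaf_ids = clf.apply(iris.data)   # leaf index for each sample

# Map leaf id -> list of sample indices that end up in that leaf
samples_by_leaf = defaultdict(list)
for sample_idx, leaf in enumerate(leaf_ids):
    samples_by_leaf[int(leaf)].append(sample_idx)

# One-off lookup for a single leaf (hypothetical choice of leaf)
some_leaf = int(leaf_ids[0])
idx = np.flatnonzero(leaf_ids == some_leaf)
print(some_leaf, idx)
```

Note that this only covers leaves; internal nodes never appear in the output of apply.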
If you want the samples for every node, internal nodes included, you can compute all the decision paths with
dec_paths = clf.decision_path(iris.data)
Then loop over the decision paths, convert each one to a dense array with toarray(), and check whether the sample passes through each node. Everything is stored in a defaultdict whose keys are node numbers and whose values are lists of sample numbers.
for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)
Complete code
import sklearn.datasets
import sklearn.tree
import collections

clf = sklearn.tree.DecisionTreeClassifier(random_state=42)
iris = sklearn.datasets.load_iris()
clf = clf.fit(iris.data, iris.target)

samples = collections.defaultdict(list)
dec_paths = clf.decision_path(iris.data)

for d, dec in enumerate(dec_paths):
    for i in range(clf.tree_.node_count):
        if dec.toarray()[0][i] == 1:
            samples[i].append(d)
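Calling toarray() inside a double loop gets slow on larger datasets. One alternative, sketched here as an assumption rather than the answer's method, is to convert the sparse indicator matrix returned by decision_path to column-oriented (CSC) form, where the samples passing through node n are simply the row indices stored for column n.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# Sparse indicator matrix: entry (s, n) is nonzero iff sample s passes through node n
dec_paths = clf.decision_path(iris.data)

# Column-oriented format makes per-node lookups cheap
by_node = dec_paths.tocsc()

samples = {}
for node in range(clf.tree_.node_count):
    # row indices of the nonzero entries stored for column `node`
    start, end = by_node.indptr[node], by_node.indptr[node + 1]
    samples[node] = sorted(by_node.indices[start:end].tolist())
```

The resulting dict has the same keys and sample lists as the defaultdict built by the loop above.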
Output
print(samples[13])
[70, 126, 138]