
Finding a corresponding leaf node for each data point in a decision tree (scikit-learn)

I'm using the decision tree classifier from the scikit-learn package in Python 3.4, and I want to get the corresponding leaf node id for each of my input data points.

For example, my input might look like this:

array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2]])

and let's suppose the corresponding leaf nodes are 16, 5 and 45 respectively. I want my output to be:

leaf_node_id = array([16, 5, 45])

I have read through the scikit-learn mailing list and related questions on SF, but I still can't get it to work. Here is a hint I found on the mailing list, but it still does not work for me.

http://sourceforge.net/p/scikit-learn/mailman/message/31728624/

At the end of the day, I just want a function GetLeafNode(clf, X_valida) whose output is a list of the corresponding leaf nodes. Below is the code that reproduces the error I received. Any suggestions would be much appreciated.

from sklearn.datasets import load_iris
from sklearn import tree

# load data and divide it to train and validation
iris = load_iris()

num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]

y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]

# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)

# Now I want to know the corresponding leaf node id for each of my training data points
clf.tree_.apply(X_train)

# This gives the error message below:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-17-2ecc95213752> in <module>()
----> 1 clf.tree_.apply(X_train)

_tree.pyx in sklearn.tree._tree.Tree.apply (sklearn/tree/_tree.c:19595)()

ValueError: Buffer dtype mismatch, expected 'DTYPE_t' but got 'double'
neobot asked May 23 '15 01:05



1 Answer

Since scikit-learn 0.17 you can use the apply method of the fitted DecisionTreeClassifier to get the index of the leaf in which each data point ends up. Building on the code in neobot's question:

from sklearn.datasets import load_iris
from sklearn import tree

# load data and divide it to train and validation
iris = load_iris()

num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]

y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]

# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)

# Compute the leaf node id for each of my training data points
clf.apply(X_train)

This produces the output:

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2])
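
To get the GetLeafNode(clf, X_valida) helper asked for in the question, the call above can be wrapped in a small function. The sketch below is one way to do it, assuming scikit-learn 0.17 or later; the name get_leaf_node and the random_state=0 argument are illustrative choices, not part of the original code. It also shows what is most likely behind the original error: the low-level clf.tree_.apply expects float32 input, so casting the array first should let that call go through as well.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree


def get_leaf_node(clf, X):
    """Return the id of the leaf reached by each row of X.

    Hypothetical helper name; it only wraps the estimator's public
    apply method (available since scikit-learn 0.17).
    """
    return clf.apply(X)


# same split as in the question
iris = load_iris()
num_train = 100
X_train, y_train = iris.data[:num_train], iris.target[:num_train]
X_valida = iris.data[num_train:]

# random_state is fixed here only to make the sketch repeatable
clf = tree.DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# one leaf id per validation row
leaf_node_id = get_leaf_node(clf, X_valida)

# low-level route: the underlying Tree object wants float32 input,
# so cast explicitly before calling tree_.apply
low_level_id = clf.tree_.apply(X_valida.astype(np.float32))
```

Both routes should return the same ids. The public clf.apply is the safer choice because it validates and converts the input for you.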
albarji answered Sep 21 '22 16:09
