
How do I get all Gini indices in my decision tree?

I have built a decision tree with scikit-learn, specifically sklearn.tree.DecisionTreeClassifier().fit(x, y).

How do I get the Gini indices of all candidate nodes at each step? Graphviz only shows the Gini index of the node with the lowest Gini index, i.e., the node actually used for the split.

For example, the image below (from Graphviz) tells me that the Gini score of the Pclass_lowVMid right child is 0.408, but not the Gini index of Pclass_lower or Sex_male at that step. All I know is that the Gini index of Pclass_lower and Sex_male must be at least 0.408*0.7 + 0.

[Image: decision tree rendered by Graphviz]

asked Dec 10 '18 at 03:12 by vivian


2 Answers

export_graphviz shows the impurity of every node, at least in scikit-learn 0.20.1:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from graphviz import Source  # the graphviz Python package (needs the Graphviz binaries)

data = load_iris()
X, y = data.data, data.target

clf = DecisionTreeClassifier(max_depth=2, random_state=42)
clf.fit(X, y)

# impurity=True is the default, so every node box includes its gini value
graph = Source(export_graphviz(clf, out_file=None, feature_names=data.feature_names))
graph.format = 'png'
graph.render('dt', view=True)  # writes dt and dt.png, then opens the image


The impurity values of all nodes are also available in the impurity attribute of the fitted tree object, clf.tree_:

clf.tree_.impurity
array([0.66666667, 0.        , 0.5       , 0.16803841, 0.04253308])
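The impurity array can be combined with the tree's other per-node arrays to see, for each internal node, the weighted impurity of its two children. Below is a minimal sketch that rebuilds the iris tree from above so it runs on its own. split_report is a hypothetical helper name. children_left, children_right, feature, threshold, impurity, n_node_samples and weighted_n_node_samples are attributes of the fitted clf.tree_ object, and a child index of -1 marks a leaf.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def split_report(clf):
    """Hypothetical helper: per-node impurity plus the weighted impurity of each split."""
    t = clf.tree_
    report = []
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        entry = {"node": node, "gini": float(t.impurity[node]),
                 "samples": int(t.n_node_samples[node])}
        if left != -1:  # internal node; leaves have children_left == -1
            n_l = t.weighted_n_node_samples[left]
            n_r = t.weighted_n_node_samples[right]
            # Children's Gini weighted by their share of this node's samples
            entry["split_gini"] = float(
                (t.impurity[left] * n_l + t.impurity[right] * n_r) / (n_l + n_r))
            entry["feature"] = int(t.feature[node])
            entry["threshold"] = float(t.threshold[node])
        report.append(entry)
    return report


data = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(data.data, data.target)
report = split_report(clf)

for e in report:
    if "split_gini" in e:
        name = data.feature_names[e["feature"]]
        print(f"node {e['node']}: gini={e['gini']:.4f}, split on {name} <= "
              f"{e['threshold']:.3f}, weighted child gini={e['split_gini']:.4f}")
    else:
        print(f"node {e['node']}: leaf, gini={e['gini']:.4f}, samples={e['samples']}")
```

Only the split the tree actually chose is stored per node, so this reports the winning split's score, not the scores of the alternatives.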
answered Sep 26 '22 at 14:09 by Kevin


The Gini index of the Pclass node is the weighted average of its children's Gini indices, each weighted by that child's share of the samples:

Gini(Pclass) = Gini(left) * n_left / (n_left + n_right) + Gini(right) * n_right / (n_left + n_right)

where n_left and n_right are the numbers of samples in the left and right child nodes. So here it is:

Gini(Pclass) = 0 * (3/10) + 0.408 * (7/10) = 0.2856
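Applying this formula to every candidate split gives the numbers the question asks for. The fitted tree keeps only the winning split, so the alternatives have to be recomputed. Below is a minimal brute-force sketch. It uses the iris data from the other answer, since the question's Titanic features aren't available here. It first checks the arithmetic above, then scores every feature/threshold pair at the root. Thresholds are midpoints between consecutive distinct values, which is similar in spirit to scikit-learn's splitter but is not its actual code. Assumptions: numeric features, no sample weights, and criterion='gini'. gini, weighted_split_gini and candidate_splits are hypothetical helper names. For a 0/1 dummy column such as Pclass_lower, the only candidate threshold is 0.5.

```python
import numpy as np
from sklearn.datasets import load_iris


def gini(labels):
    """Gini impurity 1 - sum(p_k^2) of a set of class labels."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - float(np.sum(p ** 2))


def weighted_split_gini(gini_left, n_left, gini_right, n_right):
    """Gini of a split: each child's impurity weighted by its share of the samples."""
    n = n_left + n_right
    return gini_left * (n_left / n) + gini_right * (n_right / n)


# The worked example above: a pure left child with 3/10 of the samples and a
# right child with Gini 0.408 and 7/10 of them (3 and 7 are illustrative counts).
print(round(weighted_split_gini(0.0, 3, 0.408, 7), 4))  # 0.2856


def candidate_splits(X, y):
    """Weighted Gini of every (feature, threshold) split of the rows X, y."""
    results = []
    for f in range(X.shape[1]):
        values = np.unique(X[:, f])  # sorted distinct values of feature f
        for lo, hi in zip(values[:-1], values[1:]):
            thr = (lo + hi) / 2.0  # midpoint threshold
            left = X[:, f] <= thr
            score = weighted_split_gini(gini(y[left]), int(left.sum()),
                                        gini(y[~left]), int((~left).sum()))
            results.append((f, float(thr), score))
    return results


data = load_iris()
splits = candidate_splits(data.data, data.target)

# Best score each feature can reach at this node (the tree keeps only the overall winner)
best = {}
for f, thr, score in splits:
    if f not in best or score < best[f][1]:
        best[f] = (thr, score)
for f, (thr, score) in sorted(best.items()):
    print(f"{data.feature_names[f]} <= {thr:.3f}: weighted gini = {score:.4f}")
```

To score the alternatives at a deeper node, pass only the rows that reach that node. For a fitted tree, those rows can be selected with clf.decision_path(X), which returns a sparse (n_samples, n_nodes) indicator matrix.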
answered Sep 24 '22 at 14:09 by Rishi Bansal