
How do I find which attributes my tree splits on, when using scikit-learn?

I have been exploring scikit-learn, making decision trees with both entropy and gini splitting criteria, and exploring the differences.

My question is: how can I "open the hood" and find out exactly which attributes the trees are splitting on at each level, along with their associated information values, so I can see where the two criteria make different choices?

So far, I have explored the 9 methods outlined in the documentation. They don't appear to allow access to this information. But surely this information is accessible? I'm envisioning a list or dict that has entries for node and gain.

Thanks for your help and my apologies if I've missed something completely obvious.

tumultous_rooster asked Nov 23 '13 00:11




3 Answers

Directly from the documentation ( http://scikit-learn.org/0.12/modules/tree.html ):

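from io import StringIO
from sklearn import tree

out = StringIO()
tree.export_graphviz(clf, out_file=out)  # writes into the buffer and returns None, so don't reassign out
dot_source = out.getvalue()              # the tree as Graphviz DOT text

The standalone StringIO module was removed in Python 3; import StringIO from the io module instead, as above.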

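There is also the tree_ attribute on your fitted decision tree object, which gives direct access to the whole structure as parallel arrays indexed by node id (node 0 is the root).

You can simply read them:

clf.tree_.children_left   # id of each node's left child (-1 for a leaf)
clf.tree_.children_right  # id of each node's right child (-1 for a leaf)
clf.tree_.feature         # index of the feature each node splits on (-2 for a leaf)
clf.tree_.threshold       # split threshold of each node
clf.tree_.value           # class distribution (or predicted value) stored at each node
clf.tree_.impurity        # impurity (gini or entropy) at each node

For more details, look at the source code of the export_graphviz function.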
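Since the question asks for the gain at each split, here is a minimal sketch of how it could be derived from those arrays. The helper name split_report, the iris data, and the max_depth setting are my own choices for illustration, not part of scikit-learn. The gain is computed as the weighted impurity decrease: the node's impurity minus the sample-weighted impurities of its two children, scaled by the fraction of all samples that reach the node.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def split_report(clf, feature_names):
    """Return one dict per internal (splitting) node of a fitted tree.

    split_report is a hypothetical helper for illustration only.
    """
    t = clf.tree_
    total = t.weighted_n_node_samples[0]  # samples at the root
    rows = []
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaves have no children, so there is no split
            continue
        n = t.weighted_n_node_samples[node]
        n_left = t.weighted_n_node_samples[left]
        n_right = t.weighted_n_node_samples[right]
        # Weighted impurity decrease produced by this split
        gain = (n / total) * (
            t.impurity[node]
            - (n_left / n) * t.impurity[left]
            - (n_right / n) * t.impurity[right]
        )
        rows.append({
            "node": node,
            "feature": feature_names[t.feature[node]],
            "threshold": float(t.threshold[node]),
            "impurity": float(t.impurity[node]),
            "gain": float(gain),
        })
    return rows


# Assumed example data; substitute your own X, y and feature names.
iris = load_iris()
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0)
clf.fit(iris.data, iris.target)

rows = split_report(clf, iris.feature_names)
for row in rows:
    print(row)
```

Fitting a second tree with criterion="entropy" and running the same helper on it gives two lists you can compare node by node.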

In general you can use the inspect module

from inspect import getmembers
print( getmembers( clf.tree_ ) )

to list all of the object's members

Decision tree visualization from sklearn docs

lejlot answered Oct 13 '22 19:10


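If you just want a quick look at what is going on in the tree, try:

list(zip(X.columns[clf.tree_.feature], clf.tree_.threshold, clf.tree_.children_left, clf.tree_.children_right))

where X is the data frame of independent variables and clf is the fitted decision tree object (in Python 3, zip is lazy, hence the list() call). Each tuple describes one node, in node-id order. clf.tree_.children_left and clf.tree_.children_right hold the ids of each node's left and right child (each one of these corresponds to an arrow in the graphviz visualization). For leaf nodes both children are -1 and clf.tree_.feature is -2, so the column name shown on a leaf row is just X.columns[-2] and should be ignored.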
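A rough variant of the same one-liner that labels leaves explicitly instead of printing a misleading column name. It assumes the iris data and a plain list of feature names in place of a data frame, and the "leaf" label is my own convention.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Assumed example data; with a DataFrame X you would use list(X.columns) for names.
iris = load_iris()
names = iris.feature_names
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

t = clf.tree_
nodes = [
    # (node id, feature name or "leaf", threshold, left child id, right child id)
    (i, names[f] if f >= 0 else "leaf", float(thr), int(left), int(right))
    for i, (f, thr, left, right) in enumerate(
        zip(t.feature, t.threshold, t.children_left, t.children_right)
    )
]

for row in nodes:
    print(row)
```

On leaf rows the threshold is only a placeholder, so read it only on rows with a real feature name.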

Daniel Gibson answered Oct 13 '22 18:10


scikit-learn introduced a delicious new function called export_text in version 0.21 (May 2019) to view all the rules from a tree. See the sklearn.tree.export_text documentation.

Once you've fit your model, you just need two lines of code. First, import export_text:

from sklearn.tree import export_text  # the sklearn.tree.export submodule was deprecated in 0.22 and later removed

Second, create an object that will contain your rules. To make the rules look more readable, use the feature_names argument and pass a list of your feature names. For example, if your model is called model and your features are named in a dataframe called X_train, you could create an object called tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Then just print or save tree_rules. Your output will look like this:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
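To see where the gini and entropy criteria make different choices, one option is to print the rules of two otherwise identical trees one after the other. This is only a sketch: the iris data, max_depth=3 and random_state=0 are assumptions for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumed example data; substitute your own X_train / y_train.
iris = load_iris()

rules = {}
for criterion in ("gini", "entropy"):
    model = DecisionTreeClassifier(criterion=criterion, max_depth=3, random_state=0)
    model.fit(iris.data, iris.target)
    # feature_names makes the rules readable instead of feature_0, feature_1, ...
    rules[criterion] = export_text(model, feature_names=list(iris.feature_names))

for criterion, text in rules.items():
    print(f"--- {criterion} ---")
    print(text)
```

Because export_text returns a plain string, you can also feed the two outputs to a diff tool such as difflib to highlight the lines where the trees diverge.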

yzerman answered Oct 13 '22 18:10