
Trying to print out the decision tree for a forest from scikit-learn ensemble

I am trying to print out the decision trees of a forest from scikit-learn's ensemble module. For example, for a DecisionTreeClassifier, I would use:

from sklearn import tree  
clf = tree.DecisionTreeClassifier( criterion ='entropy', max_depth = 3,    
min_samples_leaf = 
clf = clf.fit( X_train, y_train) #Input this to analyze the training set.

import pydot, StringIO
dot_data = StringIO.StringIO() 
tree.export_graphviz( clf, out_file = dot_data,    
feature_names =['age', 'sex', 'first_class', 'second_class', 'third_class'])
graph = pydot.graph_from_dot_data( dot_data.getvalue())
graph.write_png('visualtree.png') 
from IPython.core.display import Image
Image( filename ='visualtree.png')

I tried a similar approach for a RandomForestRegressor (see below) and got an error:

# Fit regression model
from sklearn.ensemble import RandomForestRegressor
rfr_1 = RandomForestRegressor(n_estimators=10, max_depth=5)
rfr_1.fit(X, y)

from sklearn.ensemble import*
import pydot, StringIO
dot_data = StringIO.StringIO() 
ensemble.export_graphviz( rf1, out_file = dot_data,    
feature_names =['Temperature', 'Translator Bacteria'])
graph = pydot.graph_from_dot_data( dot_data.getvalue())
graph.write_png('fish.png') 
from IPython.core.display import Image

Image( filename ='fish.png')

File "randomforestregressor.py", line 45, in <module>
    ensemble.export_graphviz( rf1, out_file = dot_data,
NameError: name 'ensemble' is not defined

How would I accomplish this? Thanks!

Chris Rigano asked Dec 30 '13 15:12



1 Answer

The error message is pretty explicit:

File "randomforestregressor.py", line 45, in <module>
    ensemble.export_graphviz( rf1, out_file = dot_data,
NameError: name 'ensemble' is not defined

You access a variable named ensemble on line 45 of your script, but you never define such a variable. You probably intended that name to refer to the sklearn.ensemble package:

from sklearn import ensemble

However, if you do this you will likely get an AttributeError, because the sklearn.ensemble package has no export_graphviz function.
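A quick way to confirm where the function lives, assuming a reasonably recent scikit-learn installation (the variable names here are only illustrative):

```python
from sklearn import ensemble, tree

# export_graphviz is defined in sklearn.tree, not in sklearn.ensemble.
in_ensemble = hasattr(ensemble, 'export_graphviz')
in_tree = hasattr(tree, 'export_graphviz')

print(in_ensemble)  # False
print(in_tree)      # True
```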

Instead, what you might want to do is generate one image per tree in the forest by iterating over the elements of the rfr_1.estimators_ list and calling the export_graphviz function of the sklearn.tree module on each of those trees.
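A minimal sketch of that loop follows. The training data is a synthetic stand-in (the X and y from the question are not shown), the output directory and file names are made up for illustration, and out_file=None is used so that export_graphviz returns the DOT source as a string, which assumes a scikit-learn version that supports this (0.18 or later):

```python
import os
import tempfile

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import export_graphviz

# Synthetic stand-in data: an assumption, not the data from the question.
rng = np.random.RandomState(0)
X = rng.rand(100, 2)
y = X[:, 0] + 2 * X[:, 1]

rfr_1 = RandomForestRegressor(n_estimators=10, max_depth=5, random_state=0)
rfr_1.fit(X, y)

feature_names = ['Temperature', 'Translator Bacteria']
out_dir = tempfile.mkdtemp()  # hypothetical output location

dot_sources = []
# estimators_ holds the fitted DecisionTreeRegressor objects of the forest.
for i, estimator in enumerate(rfr_1.estimators_):
    # With out_file=None the DOT description is returned as a string.
    dot = export_graphviz(estimator, out_file=None,
                          feature_names=feature_names)
    dot_sources.append(dot)
    with open(os.path.join(out_dir, 'fish_tree_%d.dot' % i), 'w') as f:
        f.write(dot)

print('Wrote %d DOT files to %s' % (len(dot_sources), out_dir))
```

Each .dot file can then be turned into an image with the Graphviz command-line tool, for example dot -Tpng fish_tree_0.dot -o fish_tree_0.png, provided Graphviz is installed. Alternatively, the DOT string can be passed to pydot as in the question.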

However, in practice, displaying the trees of a forest is very often useless. Practitioners typically build random forests with hundreds or thousands of trees to get good predictive accuracy. In such cases, visually inspecting that many trees is impractical.

ogrisel answered Sep 28 '22 08:09
