Is there a way to print a trained decision tree in scikit-learn? I want to train a decision tree for my thesis and I want to put the picture of the tree in the thesis. Is that possible?
All decision trees use np.float32 arrays internally. If training data is not in this format, a copy of the dataset will be made.
There is a function to export the tree in Graphviz (DOT) format: http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
So from the online docs:
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
... out_file='tree.dot')
You can then open this file with Graphviz, or, if you have pydot installed, render it more directly: http://scikit-learn.org/stable/modules/tree.html
>>> from io import StringIO  # sklearn.externals.six is no longer shipped with recent scikit-learn
>>> import pydot
>>> dot_data = StringIO()
>>> tree.export_graphviz(clf, out_file=dot_data)
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue())
>>> graph.write_pdf("iris.pdf")
This writes the tree to iris.pdf. I can't display the image here, so follow the link to see the rendered tree from the docs: http://scikit-learn.org/stable/_images/iris.svg
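By default the exported nodes are labelled with generic feature indices, which is not ideal for a thesis figure. The sketch below passes the iris feature and class names to export_graphviz and asks for filled, rounded boxes. It uses out_file=None so the DOT source comes back as a string; the variable name dot_source and the fixed random_state are my own choices, not part of the original answer.

```python
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
# random_state is fixed only so repeated runs give the same tree
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the DOT source as a string
dot_source = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,       # e.g. "petal width (cm)"
    class_names=list(iris.target_names),    # setosa, versicolor, virginica
    filled=True,                            # colour nodes by majority class
    rounded=True,                           # rounded node boxes
)

# Save the DOT source so it can be rendered with Graphviz later
with open("tree_named.dot", "w") as f:
    f.write(dot_source)
```

If the Graphviz executables are on your PATH, the saved file can be rendered from the command line with dot -Tpdf tree_named.dot -o tree_named.pdf.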
Update
It seems the behaviour has changed since I first answered this question: pydot.graph_from_dot_data now returns a list,
and hence you get this error:
AttributeError: 'list' object has no attribute 'write_pdf'
When you see this, it's worth printing the object and inspecting it; most likely what you want is the first element:
graph[0].write_pdf("iris.pdf")
Thanks to @NickBraunagel for the comment
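If the same script has to run against both older and newer pydot versions, one option is a small guard that accepts either return type. This is only a sketch: first_graph is a hypothetical helper of my own, not part of pydot or scikit-learn.

```python
def first_graph(result):
    """Return the first graph if `result` is a list or tuple, else `result` itself.

    Hypothetical helper: it only inspects the type of whatever
    pydot.graph_from_dot_data handed back.
    """
    if isinstance(result, (list, tuple)):
        if not result:
            raise ValueError("no graphs were parsed from the DOT data")
        return result[0]
    return result
```

With this in place, first_graph(pydot.graph_from_dot_data(dot_data.getvalue())).write_pdf("iris.pdf") should work whether a single graph or a list comes back.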
Although I'm late to the game, the below comprehensive instructions could be useful for others who want to display decision tree output:
Install necessary modules:
1. graphviz. I used conda's install package (recommended over pip install graphviz, as pip install doesn't include the actual GraphViz executables)
2. pydot via pip (pip install pydot)
Note that graph is a list containing the pydot.Dot object:
from sklearn.datasets import load_iris
from sklearn import tree
from io import StringIO  # sklearn.externals.six is no longer shipped with recent scikit-learn
import pydot
clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph[0].write_pdf("iris.pdf") # must access graph's first element
Now you'll find iris.pdf in your current working directory.
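If installing Graphviz is a problem, an alternative not covered in the answers above is sklearn.tree.plot_tree, which draws the tree with matplotlib. The sketch below assumes scikit-learn 0.21 or later and matplotlib are installed; the output filename and figure size are arbitrary choices of mine.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is required
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# Draw the fitted tree onto a matplotlib axis
fig, ax = plt.subplots(figsize=(12, 8))
tree.plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)

# A vector PDF scales cleanly when included in a thesis
fig.savefig("iris_tree.pdf", bbox_inches="tight")
plt.close(fig)
print(os.path.abspath("iris_tree.pdf"))
```

The printed path shows where the file ended up, which avoids guessing about the working directory.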