 

Displaying a scikit-learn decision tree figure in a Jupyter notebook

I am currently creating a machine learning Jupyter notebook as a small project and want to display my decision trees. However, every option I can find involves exporting the graphic to a file and then loading it back as an image, which is rather cumbersome.

Is there a way to display my decision trees directly in the notebook, without exporting and reloading graphics?

asked Jan 25 '19 17:01 by Jürgen Erhardt


2 Answers

As of scikit-learn version 0.21 (released around May 2019), decision trees can be plotted with matplotlib using scikit-learn's tree.plot_tree, without relying on graphviz.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

X, y = load_iris(return_X_y=True)

# Make an instance of the model
clf = DecisionTreeClassifier(max_depth=5)

# Train the model on the data
clf.fit(X, y)

fn = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
cn = ['setosa', 'versicolor', 'virginica']

# Setting dpi=300 makes the image sharper than the default
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)

# Draw onto the axes created above; the trailing semicolon suppresses
# the printed list of annotation objects in Jupyter
tree.plot_tree(clf,
               feature_names=fn,
               class_names=cn,
               filled=True,
               ax=axes);

# You can save your plot if you want
# fig.savefig('imagename.png')

A figure similar to the one below will be displayed in your Jupyter notebook.
[Image: the fitted decision tree drawn by tree.plot_tree]

The code was adapted from this post.
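A minimal variation of the snippet above, assuming scikit-learn 0.21 or later: the feature and class names are taken from the dataset object returned by load_iris() instead of being typed by hand, the axes are passed explicitly via ax, and plot_tree's own max_depth argument limits how many levels are drawn (the fitted tree itself is not changed). The figure size, dpi, random_state and drawn depth of 2 are arbitrary choices for illustration.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Load the dataset as a Bunch so the feature/class names come with the data
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(8, 5), dpi=150)

# Draw only the top two levels of the fitted tree onto the given axes;
# deeper nodes are shown as "(...)". plot_tree returns the list of
# text annotations it created.
annotations = plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
    max_depth=2,
    ax=ax,
)

# In Jupyter the figure is shown inline; in a plain script, call plt.show()
```

Limiting the drawn depth this way is mainly useful for deeper trees, where drawing every node makes the text in each box too small to read.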

answered Sep 20 '22 04:09 by Michael James Kali Galarnyk


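You can show the tree directly in the notebook using IPython.display and the graphviz Python package (the Graphviz executables must also be installed for the tree to render):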

import graphviz
from IPython.display import display
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_graphviz

# Generate a simple two-feature regression dataset
X, y = make_regression(n_features=2, n_informative=2, random_state=0)

# Fit a shallow regression tree
reg = DecisionTreeRegressor(random_state=0, max_depth=2)
reg.fit(X, y)

# Render the DOT source returned by export_graphviz inline in the notebook
display(graphviz.Source(export_graphviz(reg)))
answered Sep 20 '22 04:09 by hellpanderr