I'm trying to run this machine learning tree algorithm code in IPython:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:, 2:] # petal length and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)
from sklearn.tree import export_graphviz
export_graphviz(tree_clf, out_file=image_path("iris_tree.dot"),
                feature_names=iris.feature_names[2:],
                class_names=iris.target_names,
                rounded=True,
                filled=True
                )
But I get this error when I run it in IPython:
I'm unfamiliar with export_graphviz, does anyone have any idea how to correct this?
You may have already found what you were looking for, but in case you haven't, all you need to do is replace:
out_file=image_path("iris_tree.dot")
with:
out_file="iris_tree.dot"
This creates the .dot file in the current working directory (normally the directory you started IPython or your script from). You can also give an absolute path to save the .dot file somewhere else, for example:
out_file="/home/cipher/iris_tree.dot"
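For reference, here is a minimal sketch of the whole corrected script, assuming a reasonably recent scikit-learn (random_state=42 is an addition of mine so the tree is reproducible). It writes iris_tree.dot using a plain string path and prints the absolute path, so you can see exactly where the file landed:

```python
import os

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Train the same depth-2 tree on petal length and width
iris = load_iris()
X = iris.data[:, 2:]
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

# A plain string path; relative paths resolve against the current working directory
dot_path = "iris_tree.dot"
export_graphviz(
    tree_clf,
    out_file=dot_path,
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

# Show where the file was actually written
print("Wrote", os.path.abspath(dot_path))
```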
To get rid of the error, simply remove image_path so that you have out_file="iris_tree.dot". After running that command, a file named iris_tree.dot will be saved in your current folder. Open that file in any text editor (Notepad, Microsoft Word, etc.) and copy all of its contents. Now open your browser, search for "webgraphviz", and click the first link. Delete whatever is written in the text box, paste the contents you copied from iris_tree.dot, and click "Generate Graph". Scroll down and your graph is ready.
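If you only want to paste the tree into WebGraphviz, you can skip the file entirely: with out_file=None, export_graphviz returns the DOT source as a string. A minimal sketch (again, random_state=42 is my own addition for reproducibility):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(iris.data[:, 2:], iris.target)

# With out_file=None, export_graphviz returns the DOT source as a string
# instead of writing it to a file
dot_source = export_graphviz(
    tree_clf,
    out_file=None,
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

# Copy the printed text into WebGraphviz (or any other Graphviz viewer)
print(dot_source)
```

If Graphviz is installed locally, you can also render the saved file yourself instead of using a website: dot -Tpng iris_tree.dot -o iris_tree.png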
I guess you are following the book "Hands-On Machine Learning with Scikit-Learn and TensorFlow" by Aurélien Géron. I ran into the same problem while working through the "Decision Trees" chapter. You can always refer to his GitHub notebooks; for your code, see the "decision trees" notebook. Below is the setup code from that notebook, which is where image_path is defined. Please do have a look at the notebook as well.
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals
# Common imports
import numpy as np
import os
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "decision_trees"
def image_path(fig_id):
    return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id)

def save_fig(fig_id, tight_layout=True):
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(image_path(fig_id) + ".png", format='png', dpi=300)
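Note that image_path only builds the path ./images/decision_trees/<fig_id>. It does not create any folders, and export_graphviz raises FileNotFoundError if that folder is missing. A minimal sketch that creates the folder first, assuming Python 3 (the exist_ok argument is not available in Python 2) and that you are fine with creating it under the current directory:

```python
import os

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Same helper as in the notebook setup code
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "decision_trees"

def image_path(fig_id):
    return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id)

# export_graphviz does not create missing folders, so make sure
# ./images/decision_trees exists before writing into it
os.makedirs(os.path.dirname(image_path("iris_tree.dot")), exist_ok=True)

iris = load_iris()
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(iris.data[:, 2:], iris.target)

# The call from the question now works unchanged
export_graphviz(
    tree_clf,
    out_file=image_path("iris_tree.dot"),
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)
```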