 

Visualising the decision tree in sklearn

When I try to visualise the tree, I get the error shown below the code.

These are the libraries I import. Could the problem be related to Jupyter Notebook?

from sklearn import tree
import matplotlib.pyplot
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
cancer=load_breast_cancer()
x=cancer.data
y=cancer.target
clf=DecisionTreeClassifier(max_depth=1000)
x_train,x_test,y_train,y_test=train_test_split(x,y)
clf=clf.fit(x_train,y_train)
tree.plot_tree(clf.fit(x_train,y_train))

AttributeError: module 'sklearn.tree' has no attribute 'plot_tree'

Roshan asked Sep 06 '19 07:09



2 Answers

I assigned the tree to an object and added plt.show(). This works for me.

%matplotlib inline
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
x = cancer.data
y = cancer.target
clf = DecisionTreeClassifier(max_depth = 1000)
x_train,x_test,y_train,y_test = train_test_split(x,y)

# fit() returns the fitted estimator itself, so this is the tree, not a figure
fitted_clf = clf.fit(x_train, y_train)
tree.plot_tree(fitted_clf)
plt.show()

That said, I recommend using graphviz; it is much more flexible.
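A minimal sketch of the graphviz route, assuming the same breast cancer data as above. It uses sklearn.tree.export_graphviz to produce DOT source as a string. The max_depth=3 and random_state=0 values are my own assumptions to keep the output small and repeatable; they are not part of the original code. Rendering the DOT text to an image needs the separate graphviz Python package and the Graphviz binaries, so that step is left as a comment.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

cancer = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, random_state=0  # assumed seed, for repeatability
)

# A shallow tree (assumed depth) keeps the exported graph readable
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(x_train, y_train)

# out_file=None makes export_graphviz return the DOT source as a string
dot_source = export_graphviz(
    clf,
    out_file=None,
    feature_names=list(cancer.feature_names),
    class_names=list(cancer.target_names),
    filled=True,
    rounded=True,
)
print(dot_source[:200])

# Optional rendering, only if the graphviz package and binaries are installed:
# import graphviz
# graphviz.Source(dot_source).render("cancer_tree", format="png")
```

The DOT string can also be saved to a .dot file and converted with the Graphviz command-line tools.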

Anna Yashina answered Sep 30 '22 07:09



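Upgrade the scikit-learn package. plot_tree was added in scikit-learn 0.21, so older versions raise this AttributeError. Note that the package to install is named scikit-learn, not sklearn:

pip install --upgrade scikit-learn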
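
To confirm whether the installed version is the cause, a quick check along these lines may help. It only prints the version and tests whether the tree module exposes plot_tree; the variable name has_plot_tree is my own.

```python
import sklearn
from sklearn import tree

# Show which scikit-learn version this interpreter (or notebook kernel) loads
print("scikit-learn version:", sklearn.__version__)

# plot_tree is only present in sufficiently recent releases
has_plot_tree = hasattr(tree, "plot_tree")
if has_plot_tree:
    print("tree.plot_tree is available")
else:
    print("tree.plot_tree is missing; upgrade scikit-learn and restart the kernel")
```

In Jupyter, restart the kernel after upgrading, because an already running kernel keeps the old version loaded.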
Mohammad Nazari answered Sep 30 '22 08:09
