How can I adjust the size of the plot_tree graph in sklearn to make it readable?

I am trying to plot a decision tree from sklearn with plot_tree and matplotlib, but the plot doesn't look good: the tree looks squished.

Below is my code:

from sklearn import tree
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# X_train, y_train and X_test come from an earlier train/test split (not shown)

# create tree object
model_gini_class = tree.DecisionTreeClassifier(criterion='gini')

# train the model using the training set and check the training score
model_gini_class.fit(X_train, y_train)
print(model_gini_class.score(X_train, y_train))

# predict output
predicted_gini_class = model_gini_class.predict(X_test)

plt.figure()
tree.plot_tree(model_gini_class, filled=True)
plt.title("Decision trees on the Shakespeare dataset (Gini)")
plt.show()  # the tree looks squished?

So my question is:

  • Could someone tell me how to adjust the size of the sklearn plot_tree object so it doesn't look squished?

Thank you,

chico0913 asked Jan 26 '23 13:01


2 Answers

This might help: make the figure bigger when you create it, before calling plot_tree.

plt.figure(figsize=(10,10))  # width, height in inches
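A slightly fuller sketch of the same idea, assuming the Iris dataset as a stand-in for the asker's data (which isn't shown). The important part is that the figure already has the desired size when plot_tree runs, because plot_tree draws into the current axes. Passing fontsize is optional: when it is omitted, plot_tree shrinks the text until the boxes fit, which is what makes a tree on a small default figure look squished.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Iris stands in for the asker's (unavailable) training data
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(criterion="gini", random_state=0).fit(X, y)

# Create the figure at the target size BEFORE calling plot_tree,
# which draws into the current axes of the current figure
fig = plt.figure(figsize=(10, 10))  # width, height in inches
annotations = plot_tree(clf, filled=True, fontsize=10)  # fixed label size
plt.title("Decision tree (Gini)")
# plt.show()  # uncomment in an interactive session
```

plot_tree returns the list of annotation artists it drew, which is handy if you want to tweak individual boxes afterwards.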
Ozgur answered Jan 30 '23 01:01


You can do this in two ways:

Method 1
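
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Decision tree (X_train, y_train: your training data)
classifier = DecisionTreeClassifier()
classifier.fit(X_train, y_train)

_, ax = plt.subplots(figsize=(30, 30))  # resize figure
plot_tree(classifier, filled=True, ax=ax)  # draw into that axes
plt.show()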
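If you'd rather not guess a size like 30x30, one option is to derive it from the fitted tree. The helper figsize_for_tree below is hypothetical, and its inches-per-leaf and inches-per-level factors are arbitrary assumptions to tune; it only relies on the get_n_leaves() and get_depth() methods of a fitted DecisionTreeClassifier. Iris again stands in for the real data.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree


def figsize_for_tree(clf, inches_per_leaf=1.5, inches_per_level=1.5, min_size=6.0):
    """Hypothetical heuristic: wide enough for every leaf, tall enough for every level."""
    width = max(min_size, inches_per_leaf * clf.get_n_leaves())
    height = max(min_size, inches_per_level * (clf.get_depth() + 1))
    return width, height


X, y = load_iris(return_X_y=True)  # stand-in data
classifier = DecisionTreeClassifier(random_state=0).fit(X, y)

# Same pattern as Method 1, but the size follows the shape of the tree
fig, ax = plt.subplots(figsize=figsize_for_tree(classifier))
plot_tree(classifier, filled=True, ax=ax)
# plt.show()  # uncomment in an interactive session
```

Width is tied to the number of leaves because the bottom row of the plot is where boxes run out of horizontal room first.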

Method 2
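
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Decision tree (X_train, y_train: your training data)
classifier = DecisionTreeClassifier()
classifier.fit(X_train, y_train)

plt.figure(figsize=(30, 30))  # resize figure; plot_tree draws into the current axes
plot_tree(classifier, filled=True)
plt.show()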

Use whichever you prefer.
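For very large trees even 30x30 inches can be hard to read on screen. A sketch of two further options, again with Iris as stand-in data: limit how many levels are drawn with plot_tree's max_depth argument, and save the figure to a file at a higher dpi so it can be zoomed. The file name decision_tree.png and the dpi value are arbitrary choices.

```python
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()  # stand-in data
classifier = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(20, 20))
plot_tree(
    classifier,
    max_depth=3,  # only the top 3 levels are drawn; the fitted model is unchanged
    feature_names=list(iris.feature_names),  # readable split labels
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)

# Save at a higher dpi so the image can be zoomed instead of squeezed onto the screen
out_path = Path("decision_tree.png")
fig.savefig(out_path, dpi=150, bbox_inches="tight")
plt.close(fig)
```

Truncated branches are drawn as "(...)" boxes, so it is still visible where the tree continues.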

Antoine Krajnc answered Jan 30 '23 00:01