I am trying to plot a plot_tree object from sklearn with matplotlib, but my tree plot doesn't look good. It looks squished.
Below is my code:
from sklearn import tree
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# create tree object
model_gini_class = tree.DecisionTreeClassifier(criterion='gini')
# train the model using the training sets and check score
model_gini_class.fit(X_train, y_train)
model_gini_class.score(X_train, y_train)
# predict output
predicted_gini_class = model_gini_class.predict(X_test)
plt.figure()
tree.plot_tree(model_gini_class, filled=True)
plt.title("Decision trees on the Shakespear dataset (Gini)")
plt.show() # the tree looks squished?
So my question is: how do I plot my sklearn plot_tree object so it doesn't look squished? Thank you.
This might help. Set a larger figure size before calling plot_tree:
plt.figure(figsize=(10,10))
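You can do one of two things: create the axes with a larger figsize and pass them to plot_tree, or set figsize on plt.figure before calling plot_tree.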
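import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
# Option 1: decision tree drawn on explicitly created axes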
classifier = DecisionTreeClassifier()
classifier.fit(X_train, y_train)
_, ax = plt.subplots(figsize=(30,30)) # Resize figure
plot_tree(classifier, filled=True, ax=ax)
plt.show()
# Option 2: decision tree drawn on the current figure
classifier = DecisionTreeClassifier()
classifier.fit(X_train, y_train)
plt.figure(figsize=(30, 30))  # Resize figure
plot_tree(classifier, filled=True)
plt.show()
Use whichever you prefer.
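For anyone who wants to try the resizing without the asker's data, here is a minimal self-contained sketch. It assumes the iris dataset as a stand-in for X_train and y_train, and draw_tree is a hypothetical helper name, not part of sklearn. Besides figsize it also sets dpi and the fontsize argument of plot_tree, which may help when the node labels stay cramped even on a larger figure.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree


def draw_tree(model, figsize=(20, 12), dpi=100, fontsize=10, **plot_kwargs):
    """Hypothetical helper: draw a fitted tree on a figure of the given size."""
    # figsize is in inches; a larger figure gives the nodes more room
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    # plot_tree draws onto the axes we pass in and returns the node annotations
    annotations = plot_tree(model, filled=True, fontsize=fontsize, ax=ax, **plot_kwargs)
    return fig, annotations


# Assumption: iris stands in for the asker's training data
iris = load_iris()
clf = DecisionTreeClassifier(criterion="gini", random_state=0)
clf.fit(iris.data, iris.target)

fig, annotations = draw_tree(
    clf,
    figsize=(20, 12),
    feature_names=list(iris.feature_names),
    class_names=[str(name) for name in iris.target_names],
)
fig.tight_layout()
# In an interactive session, call plt.show() here,
# or save the figure with fig.savefig("tree.png").
```

If the tree is very deep, passing max_depth to plot_tree (for example max_depth=3 through plot_kwargs) limits how many levels are drawn, which can be easier to read than an ever larger figure.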