
Visualizing a decision tree from a sklearn random forest regressor

Hi, I have a fitted random forest regressor called rf.

The documentation tells me that rf.estimators_ gives a list of the fitted trees. I am interested in visualizing one of them or, if I can't, at least finding out how many nodes the tree has.

My intuition was that the plot_tree function could be called on the tree, but when I run

rf.estimators_[0].plot_tree()

I get

AttributeError: 'DecisionTreeRegressor' object has no attribute 'plot_tree'

asked Feb 11 '26 12:02 by Abijah


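1 Answer

plot_tree is a function in the sklearn.tree module, not a method of the estimator, which is why the call on the DecisionTreeRegressor fails. Import tree from sklearn and pass the desired estimator to tree.plot_tree.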

Setup:

from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_features=4, n_informative=2,
                       random_state=0, shuffle=False)
regr = RandomForestRegressor(max_depth=2, random_state=0)
regr.fit(X, y)

print(regr.predict([[0, 0, 0, 0]]))
#[-8.32987858]

Use plot_tree:

import matplotlib.pyplot as plt
from sklearn import tree
tree.plot_tree(regr.estimators_[0])
plt.show()

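The question also asks how many nodes a tree has, which can be read from the fitted estimator without plotting. Below is a minimal sketch that repeats the setup above; the variable names (first_tree, node_counts) and the feature names x0 to x3 are made up for illustration. It uses the tree_.node_count attribute and the get_n_leaves() and get_depth() methods of the fitted DecisionTreeRegressor, plus export_text for a plain-text view of the same tree.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import export_text

# Same setup as in the answer above.
X, y = make_regression(n_features=4, n_informative=2,
                       random_state=0, shuffle=False)
regr = RandomForestRegressor(max_depth=2, random_state=0)
regr.fit(X, y)

# Each entry of estimators_ is a fitted DecisionTreeRegressor.
first_tree = regr.estimators_[0]

# Size of a single tree: total nodes, leaves, and depth.
n_nodes = first_tree.tree_.node_count
n_leaves = first_tree.get_n_leaves()
depth = first_tree.get_depth()
print(f"nodes={n_nodes}, leaves={n_leaves}, depth={depth}")

# Node count of every tree in the forest.
node_counts = [est.tree_.node_count for est in regr.estimators_]
print(f"trees={len(node_counts)}, total nodes={sum(node_counts)}")

# Text rendering of the first tree; the feature names are placeholders.
feature_names = ["x0", "x1", "x2", "x3"]
text_view = export_text(first_tree, feature_names=feature_names)
print(text_view)
```

With max_depth=2 each tree has at most 7 nodes, so the text view stays short. For deeper trees, passing max_depth to plot_tree or export_text limits how much of the tree is drawn.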

answered Feb 13 '26 09:02 by n1colas.m


