Xgboost plot_tree Error: ValueError: booster must be Booster instance

I am new to xgboost and I want to visualize my xgboost model.

Here is my code. It comes from a tutorial, so it probably has no bugs.

from numpy import loadtxt
from xgboost import XGBClassifier
from xgboost import plot_tree
import matplotlib.pyplot as plt

dataset = loadtxt('pima-indians-diabetes.csv', delimiter=",")
X = dataset[:,0:8]
y = dataset[:,8]
model = XGBClassifier()
model.fit(X, y)
plot_tree(model)
plt.show()

I am using Ubuntu and have installed graphviz. Running this code gives:

Traceback (most recent call last):
  File "a.py", line 15, in <module>
    plot_tree(model)
  File "/home/statham/anaconda2/lib/python2.7/site-packages/xgboost/plotting.py", line 214, in plot_tree
    g = to_graphviz(booster, num_trees=num_trees, rankdir=rankdir, **kwargs)
  File "/home/statham/anaconda2/lib/python2.7/site-packages/xgboost/plotting.py", line 160, in to_graphviz
    raise ValueError('booster must be Booster instance')
ValueError: booster must be Booster instance

I understand that the problem is that my model is not a Booster instance. I searched Google but didn't find an answer. Can anyone tell me how to convert my model into a Booster instance? Thanks in advance.

asked Aug 28 '17 at 13:08 by Statham


1 Answer

I found the answer.

Just change

plot_tree(model)

into:

plot_tree(model._Booster)

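and it will work. `XGBClassifier` is a scikit-learn wrapper; the `_Booster` attribute holds the underlying `Booster` object that `plot_tree` expects.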

answered Nov 11 '22 at 13:11 by Statham