
How to plot an Extreme Gradient Boosting tree built with caret

I just built a basic classification model with the caret package using the "xgbTree" method (Extreme Gradient Boosting). It has great accuracy on 3 classes, but I can't see the rules or plot the tree.

Does anyone know how to plot the tree when it is built with caret? I tried using the xgb.plot.tree function from the xgboost package, but I get an error indicating it can't be plotted because my model is not an object of class xgb.Booster generated by the xgb.train function. Is there a way I can coerce the model I built in caret into an xgb.Booster object?

I appreciate any help.

HunkyGoon asked Oct 18 '22 13:10


1 Answer

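I was having the same issue. When I looked at the help page (?xgb.plot.tree), I saw that the first parameter is a vector of feature names, so you must specify the model argument by name. The fitted xgb.Booster is stored in the finalModel element of the caret train object: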

xgb.plot.tree(model = myegb$finalModel)

The above will produce the tree diagram without feature names; the nodes are labeled by column index instead.

To add feature names:

xgb.plot.tree(feature_names = myegb$finalModel$xNames, model = myegb$finalModel)
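
For anyone who wants to try this end to end, here is a minimal sketch of the whole flow. It is only an illustration under some assumptions: the iris data, the fixed tuning grid, and the names grid, booster, and tree_dt are mine, not from the question, and it assumes an xgboost 1.x release (the API used above) together with caret. xgb.model.dt.tree is included because it returns the splits as a table, which may be the easiest way to read the "rules" without a plot.

```r
library(caret)
library(xgboost)

set.seed(42)

# Assumed toy setup: a 3-class model on iris with a single fixed
# parameter combination, so no resampling is needed.
grid <- expand.grid(
  nrounds = 2, max_depth = 2, eta = 0.3, gamma = 0,
  colsample_bytree = 1, min_child_weight = 1, subsample = 1
)
myegb <- train(
  x = as.matrix(iris[, 1:4]), y = iris$Species,
  method = "xgbTree",
  trControl = trainControl(method = "none"),
  tuneGrid = grid
)

# caret keeps the underlying fitted booster in $finalModel,
# so no coercion is needed.
booster <- myegb$finalModel
stopifnot(inherits(booster, "xgb.Booster"))

# Tabular view of every split in every tree; no plotting packages required.
tree_dt <- xgb.model.dt.tree(model = booster)
print(head(tree_dt))

# The diagram itself needs the DiagrammeR package to be installed.
if (requireNamespace("DiagrammeR", quietly = TRUE)) {
  xgb.plot.tree(model = booster)
}
```

With more boosting rounds the diagram gets crowded quickly, so it is probably worth keeping nrounds and max_depth small while exploring, or filtering tree_dt by its Tree column instead of plotting everything.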

Matt L. answered Oct 20 '22 22:10