
How can I get the trained model from xgboost CV?

Tags: python, xgboost

I am running the following code:

import xgboost as xgb

params = {"objective": "reg:squarederror", "colsample_bytree": 0.3, "learning_rate": 0.15,
          "max_depth": 5, "alpha": 15}

# X_train and y_train are the training split; the held-back data is not used here
data_dmatrix = xgb.DMatrix(data=X_train, label=y_train)
cv_results = xgb.cv(dtrain=data_dmatrix, params=params, nfold=3,
                    num_boost_round=50, early_stopping_rounds=10,
                    metrics="rmse", as_pandas=True, seed=0)

The results look great, and I would like to test the best model from the cross-validation on the data I held back. But how can I get the model?

cbueltem asked Sep 05 '25 03:09



2 Answers

The XGBoost API provides a callback mechanism. Callbacks let you run custom functions before and after every boosting round (epoch), as well as before and after training.

Since you need the final models after CV, you can define a callback like this:

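import xgboost as xgb

class SaveBestModel(xgb.callback.TrainingCallback):
    """Collect the per-fold boosters once xgb.cv finishes training."""

    def __init__(self, cvboosters):
        # List supplied by the caller; it is filled in place after training
        self._cvboosters = cvboosters

    def after_training(self, model):
        # Inside xgb.cv, `model` wraps the folds; each fold exposes its Booster as .bst
        self._cvboosters[:] = [cvpack.bst for cvpack in model.cvfolds]
        return model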

With xgb.cv, the model argument passed to after_training is an instance of xgb.training._PackedBooster (a private class, so this relies on internal details that may change between versions). Now pass the callback to xgb.cv:

cvboosters = []

cv_results = xgb.cv(dtrain=data_dmatrix, params=params, nfold=3,
                    num_boost_round=50, early_stopping_rounds=10,
                    metrics="rmse", as_pandas=True, seed=0,
                    callbacks=[SaveBestModel(cvboosters)])

After xgb.cv returns, cvboosters holds one trained Booster per fold (three here, since nfold=3).
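To score the held-back data with the per-fold boosters, one option is to average their predictions. The following is a minimal sketch under that assumption, not something XGBoost does for you: average_fold_predictions is a hypothetical helper name, and X_test in the usage comment is an assumed name for the held-back features. It relies only on each booster having a predict method, as xgb.Booster does.

```python
def average_fold_predictions(boosters, dmatrix):
    """Average the predictions of several per-fold boosters (a simple ensemble).

    `boosters` is any sequence of objects with a .predict(dmatrix) method,
    for example the xgb.Booster instances collected during xgb.cv.
    """
    if not boosters:
        raise ValueError("no boosters were collected")
    # One sequence of predictions per fold
    fold_preds = [bst.predict(dmatrix) for bst in boosters]
    # Average row by row across the folds
    return [sum(vals) / len(vals) for vals in zip(*fold_preds)]


# Assumed usage with the names from this thread (X_test is hypothetical):
# preds = average_fold_predictions(cvboosters, xgb.DMatrix(X_test))
```

Keep in mind that each fold's booster was trained on only part of the training data, so this ensemble is not the same as a single model fitted on all of it.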

Vladimir answered Sep 07 '25 21:09



Unlike, say, scikit-learn's GridSearchCV, which returns a model (optionally refitted on the whole data when called with refit=True), xgb.cv does not return a model, only the evaluation history. From the docs:

Returns evaluation history

In this sense, it is similar to scikit-learn's cross_validate, which also returns only the metrics and no model.

So, provided that you are happy with the CV results and want to fit the model on all the training data, you must do that separately:

bst = xgb.train(dtrain=data_dmatrix, params=params, num_boost_round=50)
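
When early stopping is used in xgb.cv, the returned evaluation history ends at the best iteration, so its length is one reasonable choice for num_boost_round instead of a hard-coded 50. The sketch below assumes that behaviour; best_num_boost_round and fit_final_model are hypothetical helper names, not part of XGBoost.

```python
def best_num_boost_round(cv_results):
    """Number of boosting rounds suggested by an xgb.cv evaluation history.

    Assumes early stopping was enabled, so the last row of the history is
    the best iteration; without it this is just the number of rounds run.
    """
    return len(cv_results)


def fit_final_model(params, dtrain, cv_results):
    """Fit one booster on all of `dtrain` for the CV-selected number of rounds."""
    import xgboost as xgb  # imported here so the helper above has no dependencies

    return xgb.train(params=params, dtrain=dtrain,
                     num_boost_round=best_num_boost_round(cv_results))


# Assumed usage with the names from this thread:
# bst = fit_final_model(params, data_dmatrix, cv_results)
```

The resulting booster can then be evaluated on the held-back data with bst.predict.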
desertnaut answered Sep 07 '25 22:09
