
python xgboost continue training on existing model

Tags:

python

xgboost

Let's say I build an XGBoost model:

bst = xgb.train(param0, dtrain1, num_round, evals=[(dtrain1, "training")])

Where:

  • param0 is a dict of parameters passed to xgb
  • dtrain1 is a DMatrix ready to be trained
  • num_round is the number of rounds

Then, I save the model to disk:

bst.save_model("xgbmodel")

Later on, I want to reload the model I saved and continue training it with dtrain2.

Does anyone have an idea how to do it?

Eran Moshe asked Oct 29 '17 at 12:10





2 Answers

You don't even have to load the model from disk before retraining.

All you need to do is run the same xgb.train command with one additional parameter, xgb_model=, which accepts either the path of the model file you saved (as in the question) or a Booster object.

Example:

bst = xgb.train(param0, dtrain2, num_round, evals=[(dtrain2, "training")], xgb_model='xgbmodel')

Good luck!

Eran Moshe answered Sep 17 '22 at 14:09



For users who want to continue training an XGBClassifier, or any model fitted through the scikit-learn style .fit API:

from xgboost import XGBClassifier

# best_est = best number of trees
# best_lr = best learning rate
# best_subsample = best subsample ratio, between 0 and 1

params = {'objective': 'binary:logistic', 'use_label_encoder': False, 
          'seed': 27, 'eval_metric': 'logloss', 'n_estimators': best_est, 
          'learning_rate': best_lr, 'subsample': best_subsample}

# train iteration 1 below

model = XGBClassifier(**params)
model.fit(x_train_1, y_train_1)

# train iteration 2 below: continue from the booster fitted in iteration 1

model = model.fit(x_train_2, y_train_2, xgb_model=model.get_booster())

In the code above, x_train_* and y_train_* are pandas DataFrame objects.

The key point is that, when continuing training, XGBoost's core training functions always take a booster as input. So you can pass either the booster from the model object or the path of a saved model.

MSS answered Sep 18 '22 at 14:09
