LightGBM: continue training a model

Tags:

lightgbm

I am training a model using cross-validation like so:

classifier = lgb.Booster(
    params=params, 
    train_set=lgb_train_set,
)

result = lgb.cv(
    init_model=classifier,
    params=params, 
    train_set=lgb_train_set,
    num_boost_round=1000,
    early_stopping_rounds=20,
    verbose_eval=50,
    shuffle=True
)

I would like to continue training the model by running the second command multiple times (possibly with a new training set or with different parameters), so that each run keeps improving the model.

However, when I try this, it is clear that the model starts from scratch each time.

Is there a different approach to do what I am intending?

asked Aug 12 '17 21:08 by William Entriken


1 Answer

This can be solved with the init_model option of lightgbm.train, which accepts either of two objects:

  1. the filename of a saved LightGBM model, or
  2. a lightgbm Booster object.

Code illustration:

import numpy as np
import lightgbm as lgb

data = np.random.rand(1000, 10)  # 1000 entities, each with 10 features
label = np.random.randint(2, size=1000)  # binary target
# free_raw_data=False keeps the raw data, so this already-constructed Dataset can be
# reused when training continues from an existing model (init scores are recomputed)
train_data = lgb.Dataset(data, label=label, free_raw_data=False)
params = {}

# Initialize with 10 iterations
gbm_init = lgb.train(params, train_data, num_boost_round=10)
print("Initial iter# %d" % gbm_init.current_iteration())

# Example of option #1 (pass a file):
gbm_init.save_model('model.txt')
gbm = lgb.train(params, train_data, num_boost_round=10,
                init_model='model.txt')
print("Option 1 current iter# %d" % gbm.current_iteration())


# Example of option #2 (pass a lightgbm Booster object):
gbm_2 = lgb.train(params, train_data, num_boost_round=10,
                  init_model=gbm_init)
print("Option 2 current iter# %d" % gbm_2.current_iteration())

https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.train.html

answered Sep 17 '22 17:09 by Tarek Oraby