
How to use the early_stopping_rounds parameter in XGBoost

I'm working through the DataCamp tutorial Extreme Gradient Boosting with XGBoost and I'm a bit confused about one result.

When I run the code below:

import xgboost as xgb

# Create your housing DMatrix (data and y are assumed to be defined earlier in the tutorial):
housing_dmatrix = xgb.DMatrix(data=data, label=y)

# Create the parameter dictionary for each tree: params
# Note: newer XGBoost releases deprecate "reg:linear" in favor of "reg:squarederror".
params = {"objective": "reg:linear", "max_depth": 4}

# Perform cross-validation with early stopping: cv_results
cv_results = xgb.cv(dtrain=housing_dmatrix, params=params, nfold=3, num_boost_round=50, early_stopping_rounds=10, metrics="rmse", as_pandas=True, seed=123)

# Print cv_results
print(cv_results)

# Lowest mean test RMSE and the round (row index) where it occurs
mean_rmse = cv_results['test-rmse-mean'].min()
boost_rounds = cv_results['test-rmse-mean'].idxmin()
print("\tRMSE {} for {} rounds".format(mean_rmse, boost_rounds))

I get this output:

    test-rmse-mean  test-rmse-std  train-rmse-mean  train-rmse-std
0    142644.104167     705.732300    141861.109375      396.179855
1    104867.638021     109.049658    103035.130209       47.104957
2     79261.453125     573.946485     75944.486979      264.292182
3     61684.972656    1010.777141     57339.683594      525.791040
4     50591.589844    1796.898432     44487.089844      597.177883
5     43296.019531    1826.395300     35787.472656      722.298168
6     38898.184896    1751.704546     29919.669922      691.136345
7     36395.877604    1678.547627     26036.014974      643.247157
8     34805.308594    1738.570391     23456.447266      735.406007
9     33709.807292    1476.484849     21628.860677      744.936771
10    33272.565104    1321.658776     20267.642578      720.847179
11    32559.872396    1067.106826     19335.697917      583.088133
12    32151.852865     988.789741     18699.406901      577.292625
13    32147.481120     727.223369     18167.764974      552.016786
14    31866.852865     816.375161     17630.004557      540.065450
15    31786.338542     719.051429     17277.742187      521.709410
16    31541.460287     747.203840     16895.608724      416.004588
17    31397.618490     754.207664     16581.137370      410.450468
18    31445.539714     822.826225     16352.770508      469.756961
19    31452.218099     876.353819     16117.928060      450.584992
20    31385.645182     829.827774     15862.508464      384.789582
21    31296.314453     881.374220     15542.022461      410.365996
22    31290.126953     875.634123     15261.000325      382.412151
23    31252.852865     782.195943     15008.689128      490.904289
24    31210.020182     753.773505     14853.276042      531.599930
    RMSE 31210.020182333334 for 24 rounds

By changing early_stopping_rounds from 10 to 20, I get this new output:

    test-rmse-mean  test-rmse-std  train-rmse-mean  train-rmse-std
0    142644.104167     705.732300    141861.109375      396.179855
1    104867.638021     109.049658    103035.130209       47.104957
2     79261.453125     573.946485     75944.486979      264.292182
3     61684.972656    1010.777141     57339.683594      525.791040
4     50591.589844    1796.898432     44487.089844      597.177883
5     43296.019531    1826.395300     35787.472656      722.298168
6     38898.184896    1751.704546     29919.669922      691.136345
7     36395.877604    1678.547627     26036.014974      643.247157
8     34805.308594    1738.570391     23456.447266      735.406007
9     33709.807292    1476.484849     21628.860677      744.936771
10    33272.565104    1321.658776     20267.642578      720.847179
11    32559.872396    1067.106826     19335.697917      583.088133
12    32151.852865     988.789741     18699.406901      577.292625
13    32147.481120     727.223369     18167.764974      552.016786
14    31866.852865     816.375161     17630.004557      540.065450
15    31786.338542     719.051429     17277.742187      521.709410
16    31541.460287     747.203840     16895.608724      416.004588
17    31397.618490     754.207664     16581.137370      410.450468
18    31445.539714     822.826225     16352.770508      469.756961
19    31452.218099     876.353819     16117.928060      450.584992
20    31385.645182     829.827774     15862.508464      384.789582
21    31296.314453     881.374220     15542.022461      410.365996
22    31290.126953     875.634123     15261.000325      382.412151
23    31252.852865     782.195943     15008.689128      490.904289
24    31210.020182     753.773505     14853.276042      531.599930
25    31351.985677     675.589696     14622.553711      484.633075
26    31355.926432     698.751442     14430.575195      619.787710
27    31389.998698     648.971218     14204.662760      568.655261
28    31356.138021     711.192347     14013.201498      546.941784
29    31292.576172     697.337663     13823.481445      581.756303
30    31306.197266     751.915534     13579.544271      599.841587
31    31285.968099     724.867998     13361.774739      631.805694
32    31299.003255     684.342196     13251.625325      674.017166
33    31287.397786     720.993414     13034.457031      749.805941
34    31249.546224     705.225396     12844.346354      691.202512
35    31203.748698     733.094000     12618.286133      624.664883
36    31173.789063     733.643603     12495.182292      609.643329
37    31172.117839     669.111951     12387.871745      607.720998
38    31170.850912     663.908881     12240.325846      583.204898
39    31162.634765     687.274839     12082.383463      555.608639
40    31142.460287     712.869918     11930.459310      589.554611
41    31160.911459     709.512382     11804.457357      614.642441
42    31175.446615     652.927064     11729.721680      618.773653
43    31166.930339     642.093634     11624.710612      623.662847
44    31135.966797     650.693105     11415.508138      612.949992
45    31155.915365     677.580853     11300.589193      651.227389
46    31158.777344     622.338647     11231.110026      663.325455
47    31161.392578     646.070369     11111.234049      679.541984
48    31153.126302     623.746479     11038.423828      677.875949
49    31137.419922     628.337573     10952.264323      653.939810
    RMSE 31135.966797 for 44 rounds
  1. My understanding is that for the second output, all 50 steps are performed and then the minimum value is retrieved. Is this correct?
  2. Furthermore, my understanding is that the first output is identical to the second except that it stops after 24 rounds. Is this correct?

Thanks for all.

Thierry K. asked Dec 17 '17 10:12



1 Answer

  1. You are correct. Because you set rmse as your metric and did not set maximize=True, XGBoost treats lower values as better, so the best round is the one with the lowest RMSE within the allotted rounds.

  2. This is also correct. If you set early_stopping_rounds=n, XGBoost will halt before reaching num_boost_round once it has gone n rounds without an improvement in the metric.

Please consider including a sample data set so that this example is reproducible and therefore more useful to future readers.
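
In the meantime, the stopping rule can be checked against the numbers already posted in the question. The sketch below replays the test-rmse-mean column from the second output through a small helper, simulate_early_stopping, which is a hypothetical name introduced here for illustration and is not part of XGBoost. It assumes the rule is simply "stop once n rounds have passed since the best round", which may differ in detail from what the library does internally.

```python
# test-rmse-mean values copied from the second output in the question (rounds 0-49).
TEST_RMSE_MEAN = [
    142644.104167, 104867.638021, 79261.453125, 61684.972656, 50591.589844,
    43296.019531, 38898.184896, 36395.877604, 34805.308594, 33709.807292,
    33272.565104, 32559.872396, 32151.852865, 32147.481120, 31866.852865,
    31786.338542, 31541.460287, 31397.618490, 31445.539714, 31452.218099,
    31385.645182, 31296.314453, 31290.126953, 31252.852865, 31210.020182,
    31351.985677, 31355.926432, 31389.998698, 31356.138021, 31292.576172,
    31306.197266, 31285.968099, 31299.003255, 31287.397786, 31249.546224,
    31203.748698, 31173.789063, 31172.117839, 31170.850912, 31162.634765,
    31142.460287, 31160.911459, 31175.446615, 31166.930339, 31135.966797,
    31155.915365, 31158.777344, 31161.392578, 31153.126302, 31137.419922,
]


def simulate_early_stopping(scores, patience):
    """Return (best_round, last_round_run) for a metric where lower is better.

    Hypothetical helper, not XGBoost code: it stops as soon as `patience`
    rounds have passed since the best score seen so far.
    """
    best_round = 0
    best_score = scores[0]
    for current_round, score in enumerate(scores):
        if score < best_score:
            # New best: remember it and keep going.
            best_round, best_score = current_round, score
        elif current_round - best_round >= patience:
            # No improvement for `patience` rounds: stop here.
            return best_round, current_round
    # The stopping condition never fired, so every round was run.
    return best_round, len(scores) - 1


for patience in (10, 20):
    best, last = simulate_early_stopping(TEST_RMSE_MEAN, patience)
    print(f"early_stopping_rounds={patience}: best round {best}, last round run {last}")
```

With a patience of 10, rounds 25 through 34 all fail to beat round 24, so training stops there, which is consistent with the first table ending at round 24. With a patience of 20, round 35 improves on round 24 before the limit is reached, so all 50 rounds run and the minimum falls at round 44.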

data princess answered Nov 01 '22 17:11
