
xgboost watchlist parameter: DMatrix object is not iterable

I am trying to train an xgboost model on a dataset X_train, X_test. Code:

xgb_params = {
    "objective": "multi:softmax",
    "eta": 0.3,
    "num_class": 62,
    "max_depth": 10,
    "nthread": 4,
    "eval_metric": "merror",
    "print.every.n": 1
    #"silent": 1
}
num_rounds = 2

mask = np.random.choice([False, True], len(X_train), p=[0.75, 0.25])
not_mask = [not i for i in mask]

dtrain = xgb.DMatrix(X_train[not_mask], label=y[not_mask])
dtrain_watch = xgb.DMatrix(X_train[mask], label=y[mask])
dtest = xgb.DMatrix(X_test)

gbdt = xgb.train(xgb_params, dtrain, num_rounds, dtrain_watch)
preds = gbdt.predict(dtest)

But I get the error TypeError: 'DMatrix' object is not iterable on the second-to-last line (the xgb.train call).

What needs to be done?

EDIT: If I typecast to np.array deliberately, I get TypeError: iteration over a 0-d array

EDIT 2: If I avoid DMatrix entirely, like this: watchlist = list(np.append(X_train[mask], np.matrix(y[mask]).transpose(), axis=1)), I get another error:

xgboost.py", line 340, in __init__
    raise TypeError('invalid cache item: {}'.format(type(d).__name__))
TypeError: invalid cache item: matrix
sinhayash asked Mar 15 '23 21:03



2 Answers

According to the documentation, the monitoring datasets should be passed in the following form:

evallist  = [(dtest,'eval'), (dtrain,'train')]

So, your code should be changed to something like this:

evallist  = [(dtrain_watch, 'eval')]
gbdt = xgb.train(xgb_params, dtrain, num_rounds, evallist)
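To see why the bare DMatrix fails: the training function loops over the watchlist and expects each item to be a (data, name) pair. The sketch below mimics only that loop in plain Python. FakeDMatrix and train_sketch are hypothetical stand-ins made up for illustration; they are not xgboost's actual code.

```python
class FakeDMatrix:
    """Hypothetical stand-in for xgb.DMatrix; it only stores a row count."""

    def __init__(self, n_rows):
        self.n_rows = n_rows


def train_sketch(dtrain, evals=()):
    """Hypothetical stand-in for xgb.train: unpack each (data, name) pair."""
    report = {}
    for dmat, name in evals:  # iterating a bare object raises TypeError
        report[name] = dmat.n_rows
    return report


dtrain = FakeDMatrix(75)
dtrain_watch = FakeDMatrix(25)

# Passing the bare object fails, as in the question's call.
try:
    train_sketch(dtrain, dtrain_watch)
    bare_error = None
except TypeError as exc:
    bare_error = str(exc)

# Wrapping it in a list of (data, name) pairs works.
report = train_sketch(dtrain, [(dtrain_watch, "eval")])
```

The name in each pair is just a label for the evaluation output, so any descriptive string should do.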

Also, you may want to use

not_mask = ~mask

instead, so that the negated mask is a Boolean NumPy array rather than a Python list of bools. Otherwise the slicing may not select the rows you expect, even though no error is raised (check the dimensions of the training and validation data in your original code).
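A small sketch of the split with a Boolean NumPy mask, on made-up data. The array shapes, the seed, and the names X_fit and X_watch are assumptions for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)      # arbitrary seed, for repeatability
X = np.arange(40).reshape(20, 2)    # made-up data: 20 rows, 2 features
y = np.arange(20)                   # made-up labels, one per row

# Roughly 25% of the rows go to the watch set.
mask = rng.choice([False, True], size=len(X), p=[0.75, 0.25])
not_mask = ~mask                    # element-wise negation, still an ndarray

X_fit, y_fit = X[not_mask], y[not_mask]
X_watch, y_watch = X[mask], y[mask]

# Every row lands in exactly one of the two subsets.
print(X_fit.shape, X_watch.shape)
```

Printing the shapes is a quick way to check that the two subsets together cover all rows and keep every feature column.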

YS-L answered Apr 30 '23 20:04



watchlist = [(dtest, 'test'), (dtrain, 'train')]
bst = xgb.train(params, dtrain, num_rounds, watchlist)

This worked fine for me.

Yonela Nuba answered Apr 30 '23 21:04
