
Early-stopping while training neural network in scikit-learn

This question is very specific to the Python library scikit-learn. Please let me know if it would be better to post it somewhere else. Thanks!

Now the question...

I have a feed-forward neural network class ffnn based on BaseEstimator which I train with SGD. It's working fine, and I can also train it in parallel using GridSearchCV().

Now I want to implement early stopping in the function ffnn.fit(), but for this I also need access to the validation data of the fold. One way of doing this is to change the line in sklearn.grid_search.fit_grid_point() which says

clf.fit(X_train, y_train, **fit_params)

into something like

clf.fit(X_train, y_train, X_test, y_test, **fit_params)

and also change ffnn.fit() to take these arguments. This would also affect other classifiers in sklearn, which is a problem. I can avoid this by checking for some kind of a flag in fit_grid_point() which tells me when to call clf.fit() in either of the above two ways.

Can someone suggest a different way to do this where I don't have to edit any code in the sklearn library?

Alternatively, would it be right to further split X_train and y_train into train/validation sets randomly and check for a good stopping point, then re-train the model on all of X_train?

Thanks!

user1953384 asked Feb 14 '23 17:02


1 Answer

You could just make your neural network model internally extract a validation set from the passed X_train and y_train, for instance by using the train_test_split function.
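A minimal sketch of that idea, not the asker's actual ffnn class: the hypothetical estimator EarlyStoppingSGD below wraps SGDClassifier as a stand-in for the network, and the parameter names validation_fraction, max_epochs and patience are assumptions chosen for illustration. In current scikit-learn releases train_test_split and GridSearchCV are imported from sklearn.model_selection rather than sklearn.grid_search.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV, train_test_split


class EarlyStoppingSGD(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in for ffnn: SGD training with an internal validation split."""

    def __init__(self, validation_fraction=0.2, max_epochs=50, patience=3,
                 random_state=0):
        self.validation_fraction = validation_fraction
        self.max_epochs = max_epochs
        self.patience = patience
        self.random_state = random_state

    def fit(self, X, y):
        # Carve the validation set out of the fold's training data, so the
        # fit(X, y) signature that GridSearchCV calls stays unchanged.
        X_fit, X_val, y_fit, y_val = train_test_split(
            X, y, test_size=self.validation_fraction,
            random_state=self.random_state, stratify=y)
        self.classes_ = np.unique(y)
        self.model_ = SGDClassifier(random_state=self.random_state)
        best_score, stale = -np.inf, 0
        for epoch in range(1, self.max_epochs + 1):
            # One SGD epoch over the reduced training set.
            self.model_.partial_fit(X_fit, y_fit, classes=self.classes_)
            score = self.model_.score(X_val, y_val)
            if score > best_score:
                best_score, stale = score, 0
                self.best_epoch_ = epoch
            else:
                stale += 1
                if stale >= self.patience:
                    break  # no improvement for `patience` epochs
        # Note: the weights kept are those of the last epoch run,
        # not a snapshot taken at best_epoch_.
        self.n_epochs_ = epoch
        return self

    def predict(self, X):
        return self.model_.predict(X)


# Usage: the grid search only ever calls fit(X_train, y_train) on each fold.
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
search = GridSearchCV(EarlyStoppingSGD(), {"patience": [2, 5]}, cv=3)
search.fit(X, y)
```

Because the split happens inside fit(), nothing in the scikit-learn source has to change, and the fold's own test data stays untouched for scoring. Some built-in estimators, such as SGDClassifier and MLPClassifier, expose early_stopping and validation_fraction parameters that work along the same lines.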

Edit:

"Alternatively, would it be right to further split X_train and y_train into train/validation sets randomly and check for a good stopping point, then re-train the model on all of X_train?"

Yes, but that would be expensive. You could instead find the stopping point and then do a single additional pass over the validation data that you used to find it.
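The cheaper variant could look roughly like the sketch below. It is an illustration under assumptions, not the asker's code: SGDClassifier stands in for the network, and the helper name fit_with_final_validation_pass, the 20% split, and the max_epochs and patience defaults are all made up for the example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split


def fit_with_final_validation_pass(X, y, max_epochs=50, patience=3, seed=0):
    """Early-stop on an internal split, then do one SGD pass over the held-out rows."""
    X_fit, X_val, y_fit, y_val = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=y)
    classes = np.unique(y)
    model = SGDClassifier(random_state=seed)
    best_score, stale, epochs_run = -np.inf, 0, 0
    for _ in range(max_epochs):
        # One SGD epoch over the reduced training set.
        model.partial_fit(X_fit, y_fit, classes=classes)
        epochs_run += 1
        score = model.score(X_val, y_val)
        if score > best_score:
            best_score, stale = score, 0
        else:
            stale += 1
        if stale >= patience:
            break  # stopping point found
    # Instead of retraining from scratch on all of X, run a single extra
    # epoch over the validation rows so the model has seen every sample.
    model.partial_fit(X_val, y_val)
    return model, epochs_run


X, y = make_classification(n_samples=200, n_features=10, random_state=0)
model, epochs_run = fit_with_final_validation_pass(X, y)
```

The trade-off is that the validation rows get one update each while the other rows get epochs_run updates, which is usually an acceptable price for skipping a full retraining run.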

ogrisel answered Feb 22 '23 01:02