This question is very specific to the Python library scikit-learn. Please let me know if it's a better idea to post it somewhere else. Thanks!
Now the question...
I have a feed-forward neural network class ffnn based on BaseEstimator which I train with SGD. It's working fine, and I can also train it in parallel using GridSearchCV().
Now I want to implement early stopping in the function ffnn.fit() but for this I also need access to the validation data of the fold. One way of doing this is to change the line in sklearn.grid_search.fit_grid_point() which says
clf.fit(X_train, y_train, **fit_params)
into something like
clf.fit(X_train, y_train, X_test, y_test, **fit_params)
and also change ffnn.fit() to take these arguments. This would also affect other classifiers in sklearn, which is a problem. I can avoid this by checking for some kind of a flag in fit_grid_point() which tells me when to call clf.fit() in either of the above two ways.
Can someone suggest a different way to do this where I don't have to edit any code in the sklearn library?
Alternatively, would it be right to further split X_train and y_train into train/validation sets randomly and check for a good stopping point, then re-train the model on all of X_train?
Thanks!
You could just make your neural network model internally extract a validation set from the passed X_train and y_train, for instance by using the train_test_split function.
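A minimal sketch of this idea, not the asker's actual ffnn: the class name EarlyStoppingSGD and the parameters max_epochs, patience and validation_fraction are hypothetical, and SGDClassifier stands in for the neural network. It also assumes a recent scikit-learn where train_test_split and GridSearchCV live in sklearn.model_selection. The point is that fit() keeps the plain fit(X, y) signature, so GridSearchCV needs no changes.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV, train_test_split


class EarlyStoppingSGD(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in for ffnn that holds out its own validation set."""

    def __init__(self, max_epochs=50, patience=5, validation_fraction=0.2,
                 random_state=0):
        self.max_epochs = max_epochs
        self.patience = patience
        self.validation_fraction = validation_fraction
        self.random_state = random_state

    def fit(self, X, y):
        # Split the fold's training data again, inside the estimator.
        X_tr, X_val, y_tr, y_val = train_test_split(
            X, y, test_size=self.validation_fraction,
            random_state=self.random_state, stratify=y)
        self.classes_ = np.unique(y)
        self.model_ = SGDClassifier(random_state=self.random_state)
        best_score, stale = -np.inf, 0
        for epoch in range(self.max_epochs):
            # One SGD pass over the inner training split.
            self.model_.partial_fit(X_tr, y_tr, classes=self.classes_)
            score = self.model_.score(X_val, y_val)
            if score > best_score:
                best_score, stale = score, 0
                self.best_epoch_ = epoch + 1
            else:
                stale += 1
                if stale >= self.patience:
                    break  # validation accuracy stopped improving
        return self

    def predict(self, X):
        return self.model_.predict(X)


# Usage: GridSearchCV calls fit(X_train, y_train) as usual.
X, y = make_classification(n_samples=300, random_state=0)
search = GridSearchCV(EarlyStoppingSGD(), {"patience": [2, 5]}, cv=3)
search.fit(X, y)
```

Note that this sketch keeps the weights from the epoch where training stopped, not from the best epoch; restoring the best weights would need a copy of the model at each improvement.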
Edit:
Alternatively, would it be right to further split X_train and y_train into train/validation sets randomly and check for a good stopping point, then re-train the model on all of X_train?
Yes, but that would be expensive. You could just find the stopping point and then do a single additional pass over the validation data that you used to find it.
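As a rough illustration of that cheaper variant, here is a sketch under the same assumptions as before: SGDClassifier stands in for the network, and the function name fit_with_final_validation_pass and its parameters are hypothetical. After early stopping, the model takes one extra SGD pass over the held-out samples instead of being retrained from scratch.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split


def fit_with_final_validation_pass(X, y, max_epochs=50, patience=5, seed=0):
    """Early-stop on a held-out split, then train once on that split too."""
    X_tr, X_val, y_tr, y_val = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=y)
    model = SGDClassifier(random_state=seed)
    classes = np.unique(y)
    best_score, stale, epochs_run = -np.inf, 0, 0
    for _ in range(max_epochs):
        model.partial_fit(X_tr, y_tr, classes=classes)
        epochs_run += 1
        score = model.score(X_val, y_val)
        if score > best_score:
            best_score, stale = score, 0
        else:
            stale += 1
            if stale >= patience:
                break  # stopping point found
    # Single additional pass over the validation data, so the model has
    # seen every sample without a full retrain.
    model.partial_fit(X_val, y_val)
    return model, epochs_run


X, y = make_classification(n_samples=300, random_state=0)
model, epochs_run = fit_with_final_validation_pass(X, y)
```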