
Early stopping with Keras and sklearn GridSearchCV cross-validation

I wish to implement early stopping with Keras and sklearn's GridSearchCV.

The working code example below is modified from How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras. The data set may be downloaded from here.

The modification adds the Keras EarlyStopping callback to prevent over-fitting. To be effective, the callback needs the monitor='val_acc' argument so that it monitors validation accuracy. For val_acc to be available, KerasClassifier needs validation_split=0.1; otherwise EarlyStopping raises RuntimeWarning: Early stopping requires val_acc available!. Note the FIXME: code comment!

Note we could replace val_acc by val_loss!
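To make the role of monitor and patience concrete, here is a minimal pure-Python sketch of patience-based stopping. It is a simplified model of the idea, not Keras's actual implementation; the function name early_stopping_epoch and the sample histories are made up for illustration.

```python
def early_stopping_epoch(history, patience=3, mode="max"):
    """Return the 0-based epoch at which training would stop, or None.

    Simplified model of patience-based early stopping: stop once the
    monitored value has failed to improve for `patience` epochs in a row.
    mode="max" suits an accuracy-like metric, mode="min" a loss.
    """
    best = None
    wait = 0
    for epoch, value in enumerate(history):
        if best is None:
            improved = True
        elif mode == "max":
            improved = value > best
        else:
            improved = value < best
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical per-epoch validation accuracies (monitor='val_acc').
val_acc = [0.60, 0.65, 0.70, 0.69, 0.70, 0.68, 0.67]
stop_at = early_stopping_epoch(val_acc, patience=3, mode="max")

# Hypothetical per-epoch validation losses (monitor='val_loss').
val_loss = [0.70, 0.60, 0.55, 0.56, 0.57, 0.58]
stop_at_loss = early_stopping_epoch(val_loss, patience=3, mode="min")
```

In both histories the best value is reached at epoch 2 and the next three epochs fail to beat it, so the sketch stops at epoch 5. Either way the monitored quantity has to come from data the model is not trained on, which is why some validation set is needed.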

Question: How can I use the cross-validation data set generated by the GridSearchCV k-fold algorithm instead of wasting 10% of the training data for an early stopping validation set?

# Use scikit-learn to grid search the learning rate and momentum
import numpy
from sklearn.model_selection import GridSearchCV
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.optimizers import SGD

# Function to create model, required for KerasClassifier
def create_model(learn_rate=0.01, momentum=0):
    # create model
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    # Compile model
    optimizer = SGD(lr=learn_rate, momentum=momentum)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

# Early stopping
from keras.callbacks import EarlyStopping
stopper = EarlyStopping(monitor='val_acc', patience=3, verbose=1)

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = KerasClassifier(
    build_fn=create_model,
    epochs=100, batch_size=10,
    validation_split=0.1, # FIXME: Instead use GridSearchCV k-fold validation data.
    verbose=2)
# define the grid search parameters
learn_rate = [0.01, 0.1]
momentum = [0.2, 0.4]
param_grid = dict(learn_rate=learn_rate, momentum=momentum)
grid = GridSearchCV(estimator=model, param_grid=param_grid, verbose=2, n_jobs=1)

# Fitting parameters
fit_params = dict(callbacks=[stopper])
# Grid search.
grid_result = grid.fit(X, Y, **fit_params)

# summarize results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))
Justin Solms asked Jan 06 '18 12:01



1 Answer

[Old answer, before the question was edited & clarified - see updated & accepted answer above]

I am not sure I have understood your exact issue (your question is quite unclear, and you include many unrelated details, which is never good when asking a SO question - see here).

You don't have to (and actually should not) pass any validation-data arguments in your model = KerasClassifier() call (it is curious that you don't feel the same need for training data here, too). Your grid.fit() will take care of both the training and validation folds. So, provided that you want to keep the hyperparameter values from your example, this call should simply be

model = KerasClassifier(build_fn=create_model, 
                        epochs=100, batch_size=32,
                        shuffle=True,
                        verbose=1)

You can see some clear and well-explained examples regarding the use of GridSearchCV with Keras here.
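If the goal is specifically to let early stopping watch the held-out fold, one option is to run the folds by hand so that each fit receives its own held-out fold as validation data. Below is a minimal stdlib-only sketch of that loop. It assumes a toy one-feature logistic regression as a stand-in for the Keras model and synthetic data instead of the Pima data set; kfold_indices, log_loss and fit_fold are hypothetical helper names, not part of Keras or scikit-learn.

```python
import math
import random


def kfold_indices(n_samples, n_splits):
    """Yield (train_idx, val_idx) pairs for a plain, unshuffled k-fold split."""
    sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
             for i in range(n_splits)]
    start = 0
    for size in sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, val_idx
        start += size


def log_loss(w, b, xs, ys):
    """Mean binary cross-entropy of the toy model sigmoid(w * x + b)."""
    eps = 1e-12
    total = 0.0
    for x, y in zip(xs, ys):
        p = 1.0 / (1.0 + math.exp(-(w * x + b)))
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(xs)


def fit_fold(x_tr, y_tr, x_val, y_val, lr=0.5, max_epochs=200, patience=3):
    """Gradient descent on the training fold; stop when the loss on the
    held-out fold has not improved for `patience` epochs in a row."""
    w = b = 0.0
    best_loss, wait = float("inf"), 0
    for epoch in range(max_epochs):
        gw = gb = 0.0
        for x, y in zip(x_tr, y_tr):
            p = 1.0 / (1.0 + math.exp(-(w * x + b)))
            gw += (p - y) * x
            gb += p - y
        w -= lr * gw / len(x_tr)
        b -= lr * gb / len(x_tr)
        val_loss = log_loss(w, b, x_val, y_val)
        if val_loss < best_loss:
            best_loss, wait = val_loss, 0
        else:
            wait += 1
            if wait >= patience:
                break
    return best_loss, epoch + 1  # best held-out loss, epochs actually run


# Synthetic data (an assumption for this sketch, not the Pima data set).
random.seed(7)
xs = [random.gauss(0.0, 1.0) for _ in range(60)]
ys = [1 if x + random.gauss(0.0, 0.5) > 0 else 0 for x in xs]

results = []
for train_idx, val_idx in kfold_indices(len(xs), n_splits=3):
    x_tr = [xs[i] for i in train_idx]
    y_tr = [ys[i] for i in train_idx]
    x_val = [xs[i] for i in val_idx]
    y_val = [ys[i] for i in val_idx]
    results.append(fit_fold(x_tr, y_tr, x_val, y_val))

for fold, (loss, epochs) in enumerate(results):
    print("fold %d: best val loss %.4f after %d epochs" % (fold, loss, epochs))
```

With a real Keras model, the analogous step would be to pass the held-out fold through the validation_data argument of fit together with the EarlyStopping callback. Keep in mind that a fold used to pick the stopping epoch has influenced training, so a score computed on that same fold is likely to be optimistic.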

desertnaut answered Oct 20 '22 20:10
