
How to stop training when it hits a specific validation accuracy?

I am training a convolutional network and I want to stop training once the validation accuracy hits 90%. I thought about using EarlyStopping and setting baseline to .90, but then it stops training whenever the validation accuracy stays below that baseline for the given number of epochs (which is just 0 here). So my code is:

es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])

When I use this code, my training stops after the first epoch with the following output:

Train on 60000 samples, validate on 10000 samples

Epoch 1/30 60000/60000 - 7s - loss: 0.4600 - acc: 0.8330 - val_loss: 0.3426 - val_acc: 0.8787
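Why this happens can be modeled in plain Python. This is a simplified sketch, assuming the behavior of Keras releases from around that time with mode='max' (which mode='auto' selects for a monitor name containing 'acc') and min_delta=0: the baseline is used as the initial "best" value, and every epoch that fails to beat the best value counts against patience. Newer releases changed some details, so treat this as an illustration rather than the actual implementation; early_stopping_would_stop is a hypothetical helper.

```python
def early_stopping_would_stop(val_acc_history, baseline=0.90, patience=0):
    """Hypothetical, simplified model of older Keras EarlyStopping with
    mode='max' and min_delta=0. Returns the 1-based epoch at which training
    would stop, or None if it runs through every epoch."""
    best = baseline  # the baseline acts as the initial "best" value
    wait = 0
    for epoch, current in enumerate(val_acc_history, start=1):
        if current > best:
            # Improvement over the best value so far (or over the baseline)
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:  # patience=0 means stop at the first miss
                return epoch
    return None


print(early_stopping_would_stop([0.8787]))      # 1 (the run from the log above)
print(early_stopping_would_stop([0.91, 0.90]))  # 2, even though 0.90 >= baseline
```

In other words, baseline only sets the bar for the first improvement; EarlyStopping never stops because the monitored value reached it.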

What else can I try to stop my training once the validation accuracy hits 90% or above?

Here is the rest of the code:

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam

model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(152, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer=Adam(learning_rate=0.001),loss='sparse_categorical_crossentropy', metrics=['accuracy'])
es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])

Thank you!

asked Jan 25 '23 11:01 by glslmn


2 Answers

The EarlyStopping callback stops training when a monitored value stops improving (increasing or decreasing, depending on the mode), so it's not a good fit for your problem. However, tf.keras lets you write custom callbacks.

For your example:

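import tensorflow as tf

class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold  # e.g. 0.9 for 90% validation accuracy

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes this epoch's metrics in `logs`; the key is "val_acc"
        # in older TF versions (see the note below for TF 2.3+)
        val_acc = logs["val_acc"]
        if val_acc >= self.threshold:
            # Keras checks this flag after the epoch and ends training
            self.model.stop_training = True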

For TF version 2.3 or above, you might have to use "val_accuracy" instead of "val_acc". Thank you Christian Westbrook for the note in the comments.

At the end of each epoch, the callback above reads the validation accuracy from the logs, compares it with the user-defined threshold (90% in your case), and stops training if the criterion is met.
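To see that control flow without TensorFlow, here is a framework-free sketch under simplifying assumptions: ThresholdCheck and fake_fit are hypothetical stand-ins (not Keras classes) that mimic how fit calls on_epoch_end after every epoch and leaves its loop once stop_training is set. The accuracy values are made up for illustration.

```python
class ThresholdCheck:
    """Hypothetical stand-in for MyThresholdCallback that needs no TensorFlow."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.stop_training = False  # plays the role of self.model.stop_training

    def on_epoch_end(self, epoch, logs):
        if logs["val_accuracy"] >= self.threshold:
            self.stop_training = True


def fake_fit(val_acc_per_epoch, callback):
    """Hypothetical, heavily simplified stand-in for model.fit.

    After each epoch it hands that epoch's logs to the callback and breaks
    out of the loop once stop_training is set. Returns the epochs run.
    """
    epochs_run = 0
    for epoch, val_acc in enumerate(val_acc_per_epoch):
        epochs_run += 1
        callback.on_epoch_end(epoch, {"val_accuracy": val_acc})
        if callback.stop_training:
            break
    return epochs_run


# Made-up validation accuracies; 0.9034 is the first value >= 0.9
print(fake_fit([0.8787, 0.8951, 0.9034, 0.9102], ThresholdCheck(0.9)))  # 3
```

Because the flag is only checked after on_epoch_end, the epoch in which the threshold is reached always runs to completion.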

With that you can simply call:

my_callback = MyThresholdCallback(threshold=0.9)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2, callbacks=[my_callback])

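Alternatively, you can override on_batch_end(...) if you want to stop mid-epoch, as soon as the threshold is reached. Note that it takes batch, logs instead of epoch, logs, and that validation metrics are only computed at the end of each epoch, so at batch level you can only check training metrics such as accuracy.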
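A minimal sketch of that batch-level variant, assuming TF 2.x: it overrides on_train_batch_end, the TF 2.x name for the training-batch hook (its default implementation forwards to on_batch_end, so overriding either should work). BatchThresholdCallback is a hypothetical name, and since only training metrics exist at batch level, it checks training accuracy, which in TF 2.x is a running average over the epoch so far.

```python
import tensorflow as tf


class BatchThresholdCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: stops training mid-epoch once the training
    accuracy reported in the batch logs reaches a threshold."""

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        # "accuracy" in TF 2.3+, "acc" in older versions
        acc = logs.get("accuracy", logs.get("acc"))
        if acc is not None and acc >= self.threshold:
            # Keras checks this flag after every training batch
            self.model.stop_training = True
```

Usage is the same as above, e.g. callbacks=[BatchThresholdCallback(0.9)] in model.fit.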

answered Jan 28 '23 00:01 by sebastian-sz


The existing answer looks fine, but I've used a shorter version in the past:

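import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Check validation accuracy ('val_acc' on TF < 2.3); default 0 avoids None
        if (logs or {}).get('val_accuracy', 0) >= 0.9:
            self.model.stop_training = True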

You can use it like this:

callback = CustomCallback()

history = model.fit(..., callbacks=[callback])

answered Jan 28 '23 01:01 by Nicolas Gervais