
Early stopping with multiple conditions

I am doing multi-class classification for a recommender system (item recommendations), and I'm currently training my network with the sparse_categorical_crossentropy loss. It is therefore reasonable to perform EarlyStopping by monitoring my validation loss, val_loss, like this:

tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)

which works as expected. However, the performance of the network (recommender system) is measured by Average-Precision-at-10, which is tracked as a metric during training under the name average_precision_at_k10. Because of this, I could also perform early stopping on this metric:

tf.keras.callbacks.EarlyStopping(monitor='average_precision_at_k10', patience=10)

which also works as expected.

My problem: sometimes the validation loss increases while the Average-Precision-at-10 is improving, and vice versa. Because of this, I need to monitor both and perform early stopping if and only if both are deteriorating. What I would like to do:

tf.keras.callbacks.EarlyStopping(monitor=['val_loss', 'average_precision_at_k10'], patience=10)

which obviously does not work, since monitor expects a single metric name. Any ideas how this could be done?

Marcus asked Oct 27 '20 14:10


2 Answers

With guidance from Gerry P's answer I managed to create my own custom EarlyStopping callback, and thought I'd post it here in case anyone else is looking to implement something similar.

If neither the validation loss nor the Average-Precision-at-10 improves for patience consecutive epochs, early stopping is performed.

import numpy as np
from tensorflow import keras


class CustomEarlyStopping(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Number of consecutive epochs in which neither metric improved.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # The loss is minimized, Average-Precision-at-10 is maximized.
        self.best_v_loss = np.inf
        self.best_map10 = -np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        v_loss = logs.get('val_loss')
        map10 = logs.get('val_average_precision_at_k10')
        if v_loss is None or map10 is None:
            return  # validation metrics were not logged for this epoch

        # Track each metric's best value separately.
        loss_improved = v_loss < self.best_v_loss
        map10_improved = map10 > self.best_map10
        if loss_improved:
            self.best_v_loss = v_loss
        if map10_improved:
            self.best_map10 = map10

        # Stop only if NEITHER metric has improved for 'patience' epochs.
        if loss_improved or map10_improved:
            self.wait = 0
            # Keep the weights of the last epoch in which at least one metric improved.
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                if self.best_weights is not None:
                    print("Restoring model weights from the end of the best epoch.")
                    self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

It is then used as:

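model.fit(
    x_train,
    y_train,
    # Assumed hold-out set (x_val, y_val are placeholder names);
    # validation data is needed so that the val_* metrics are logged.
    validation_data=(x_val, y_val),
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    callbacks=[CustomEarlyStopping(patience=10)],
)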
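To sanity-check the stopping rule described above without training a model, it can be replayed on recorded metric histories. The following is only a minimal, framework-free sketch: the names BothStalledStopper and stopping_epoch and the metric values are made up for illustration and are not part of Keras.

```python
import math


class BothStalledStopper:
    """Signals a stop once neither metric has improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = math.inf     # lower is better
        self.best_map10 = -math.inf   # higher is better
        self.wait = 0                 # consecutive epochs without any improvement

    def update(self, val_loss, map10):
        """Record one epoch and return True if training should stop."""
        loss_improved = val_loss < self.best_loss
        map10_improved = map10 > self.best_map10
        if loss_improved:
            self.best_loss = val_loss
        if map10_improved:
            self.best_map10 = map10
        if loss_improved or map10_improved:
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


def stopping_epoch(val_losses, map10s, patience):
    """Return the 0-based epoch at which training would stop, or None."""
    stopper = BothStalledStopper(patience)
    for epoch, (loss, score) in enumerate(zip(val_losses, map10s)):
        if stopper.update(loss, score):
            return epoch
    return None


# Made-up histories: the loss stalls after epoch 1,
# while Average-Precision-at-10 keeps improving until epoch 3.
val_losses = [1.00, 0.90, 0.95, 0.96, 0.97, 0.98, 0.99]
map10s = [0.10, 0.12, 0.14, 0.16, 0.15, 0.15, 0.15]
print(stopping_epoch(val_losses, map10s, patience=2))
```

With these made-up numbers the combined rule stops at epoch 5, whereas monitoring val_loss alone with the same patience would already have stopped at epoch 3, while the ranking metric was still improving.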
Marcus answered Oct 19 '22 19:10


You can achieve this by creating a custom callback. Information on how to do that is located here. Below is some code that illustrates what you can do in a custom callback. The documentation I referenced shows many other options.

import tensorflow as tf
from tensorflow import keras


class LRA(keras.callbacks.Callback):  # subclass the callback class
    # Class variables can be accessed in your code outside the class
    # definition as LRA.my_class_variable and LRA.best_weights.
    my_class_variable = None  # a class variable; set it to whatever you need
    best_weights = None  # another class variable, filled in during training

    # Define an initialization function with the parameters you want to
    # feed to the callback (add further parameters as needed).
    def __init__(self, param1, param2):
        super().__init__()
        self.param1 = param1  # used below as a val_loss threshold
        self.param2 = param2  # used below as a learning rate factor
        # write any initialization code you need here

    def on_epoch_end(self, epoch, logs=None):  # runs at the end of each epoch
        logs = logs or {}
        v_loss = logs.get('val_loss')  # example: validation loss for this epoch
        acc = logs.get('accuracy')  # another example of getting log data
        LRA.best_weights = self.model.get_weights()  # example of setting a class variable
        print(f'Hello epoch {epoch} has just ended')  # message at the end of every epoch
        # Get the current learning rate from the optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        if v_loss is not None and v_loss > self.param1:
            new_lr = lr * self.param2
            # Set the new learning rate in the optimizer.
            tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr)
        # write whatever code you need
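The learning-rate rule used in this illustration (multiply the rate by param2 whenever val_loss exceeds param1) can be tried out on its own. Below is a small sketch of that rule in plain Python; the function name lr_trace and the numbers are hypothetical and only serve to show the effect.

```python
def lr_trace(val_losses, initial_lr, threshold, factor):
    """Return the learning rate in effect after each epoch.

    The rate is multiplied by `factor` after every epoch whose
    validation loss exceeds `threshold`, otherwise it is left unchanged.
    """
    lr = initial_lr
    trace = []
    for v_loss in val_losses:
        if v_loss > threshold:
            lr = lr * factor
        trace.append(lr)
    return trace


# Made-up validation losses: epochs 1 and 2 exceed the threshold of 1.0,
# so the rate is halved twice and then stays where it is.
print(lr_trace([0.9, 1.2, 1.5, 0.8], initial_lr=0.1, threshold=1.0, factor=0.5))
```

Note that this rule compares against a fixed threshold rather than the best loss seen so far, so the rate keeps shrinking for as long as the loss stays above that threshold.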
Gerry P answered Oct 19 '22 19:10