
Logging learning rate schedule in Keras via Weights and Biases

I am training a Keras model and using a custom learning rate scheduler for the optimizer (of type tf.keras.optimizers.schedules.LearningRateSchedule), and I want to log the learning rate changes via the Weights & Biases framework. I couldn't find how to pass it to the WandbCallback object or log it in any other way.

asked Mar 01 '23 by Matan Halfon

2 Answers

Updated based on Martjin's comment!

You can log a custom learning rate to Weights and Biases using a custom Keras callback.

W&B's WandbCallback cannot automatically log your custom learning rate. For this kind of custom logging: if you are using a custom training loop, you can call wandb.log() yourself; if you are using model.fit(), a custom Keras callback is the way to go. A sketch of the custom-loop case follows; the full model.fit() example comes after it.
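For completeness, a minimal sketch of the custom-training-loop case could look like this (model, loss_fn, and train_dataset are hypothetical placeholders; MyLRSchedule is the schedule defined further down):

import tensorflow as tf
import wandb

wandb.init(project='my-project')

lr_schedule = MyLRSchedule(0.001)  # the custom schedule defined below
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

for step, (x, y) in enumerate(train_dataset):        # hypothetical tf.data.Dataset
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))   # hypothetical model and loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Log the learning rate the schedule produces for this step, alongside the loss.
    wandb.log({"lr": float(lr_schedule(step)), "loss": float(loss)})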

Here is a full example with model.fit().

This is my tf.keras.optimizers.schedules.LearningRateSchedule-based scheduler:

import tensorflow as tf

class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

  def __init__(self, initial_learning_rate):
    self.initial_learning_rate = initial_learning_rate

  def __call__(self, step):
    # Decay the rate as initial_learning_rate / (step + 1), where step is the global batch index.
    return self.initial_learning_rate / (step + 1)

optimizer = tf.keras.optimizers.SGD(learning_rate=MyLRSchedule(0.001))
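As a quick sanity check (hypothetical values), this schedule decays the rate as initial_learning_rate / (step + 1):

schedule = MyLRSchedule(0.001)
print(float(schedule(0)))   # 0.001
print(float(schedule(9)))   # 0.0001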

You can get the current learning rate of the optimizer with optimizer.learning_rate(step). This can be wrapped in a custom Keras callback that calls wandb.log().

import wandb

class LRLogger(tf.keras.callbacks.Callback):
    def __init__(self, optimizer):
        super(LRLogger, self).__init__()
        self.optimizer = optimizer

    def on_epoch_end(self, epoch, logs=None):
        # optimizer.iterations holds the global step, i.e. the number of batches processed so far.
        lr = self.optimizer.learning_rate(self.optimizer.iterations)
        wandb.log({"lr": lr}, commit=False)

Note that in the wandb.log call I have used the commit=False argument. This ensures that every metric is logged at the same time step.
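As a minimal illustration of what commit=False does (the metric values are made up):

wandb.log({"lr": 0.001}, commit=False)  # buffered; no new step is written yet
wandb.log({"loss": 0.42})               # commit defaults to True: writes "loss" together with the buffered "lr" at the same step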

Call model.fit().

from wandb.keras import WandbCallback

tf.keras.backend.clear_session()
model = some_model()  # placeholder for your model-building function

model.compile(optimizer, 'categorical_crossentropy', metrics=['acc'])

wandb.init(entity='wandb-user-id', project='my-project', job_type='train')

_ = model.fit(trainloader,
              epochs=EPOCHS,
              validation_data=testloader,
              callbacks=[WandbCallback(),       # WandbCallback logs the default metrics
                         LRLogger(optimizer)])  # LRLogger logs the learning rate

wandb.finish()

Here's the W&B media panel: [image: W&B panel showing the logged learning rate curve]

answered Mar 29 '23 by ayush thakur


I would like to complete ayush-thakur's answer. Since the scheduler updates the learning rate at each batch/step, not at each epoch, the logger should retrieve the learning rate at step epoch * steps_per_epoch, where steps_per_epoch is the number of batches per epoch. This global step count is stored in optimizer.iterations.
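To make the step bookkeeping concrete, here is a small illustration (the numbers are purely hypothetical):

steps_per_epoch = 100   # illustrative: number of batches per epoch
for epoch in range(3):
    step_at_epoch_start = epoch * steps_per_epoch         # value of optimizer.iterations in on_epoch_begin
    step_at_epoch_end = (epoch + 1) * steps_per_epoch     # value of optimizer.iterations in on_epoch_end
    print(epoch, step_at_epoch_start, step_at_epoch_end)  # 0 0 100, then 1 100 200, then 2 200 300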

Picking up @ayush-thakur's code sample and changing the on_epoch_end function:

class LRLogger(tf.keras.callbacks.Callback):
    def __init__(self, optimizer):
        super(LRLogger, self).__init__()
        self.optimizer = optimizer

    def on_epoch_end(self, epoch, logs=None):
        # optimizer.iterations is the global step count (number of batches processed so far),
        # so this retrieves the learning rate the schedule produced for the last processed batch.
        lr = self.optimizer.learning_rate(self.optimizer.iterations)
        wandb.log({"lr": lr}, commit=False)

You can then use this callback in the model.fit() training process.

Note that the code above returns the learning rate of the last batch of each epoch. To get the learning rate of the first batch of each epoch, replace on_epoch_end with on_epoch_begin, as sketched below.
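For reference, that variant might look like this (the class name is arbitrary; everything else follows the callback above):

class LRLoggerAtEpochStart(tf.keras.callbacks.Callback):
    def __init__(self, optimizer):
        super(LRLoggerAtEpochStart, self).__init__()
        self.optimizer = optimizer

    def on_epoch_begin(self, epoch, logs=None):
        # At the start of an epoch, optimizer.iterations equals epoch * steps_per_epoch,
        # so this logs the rate that will be used for the epoch's first batch.
        lr = self.optimizer.learning_rate(self.optimizer.iterations)
        wandb.log({"lr": lr}, commit=False)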

answered Mar 29 '23 by cofri