Getting the current learning rate from a tf.train.AdamOptimizer

Tags:

tensorflow

I'd like to print out the learning rate for each training step of my neural network.

I know that Adam has an adaptive learning rate, but is there a way I can see this (for visualization in TensorBoard)?

asked May 02 '16 by kmace


2 Answers

All the optimizers have a private variable that holds the value of the learning rate.

In AdaGrad and gradient descent it is called self._learning_rate; in Adam it is self._lr.

So you just need to print sess.run(optimizer._lr) to get this value. sess.run is needed because it is a tensor.

answered Oct 05 '22 by Salvador Dali

Sung Kim's suggestion worked for me; my exact steps were:

lr = 0.1
step_rate = 1000
decay = 0.95

global_step = tf.Variable(0, trainable=False)
increment_global_step = tf.assign(global_step, global_step + 1)
learning_rate = tf.train.exponential_decay(lr, global_step, step_rate, decay, staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=0.01)
trainer = optimizer.minimize(loss_function)

# Some code here

# Note: query the optimizer, not the minimize op (trainer), for the rate
print('Learning rate: %f' % (sess.run(optimizer._lr)))
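To see what values that schedule actually produces, here is a minimal pure-Python sketch of the staircase exponential decay (an illustration of the formula, not TensorFlow code), using the same hyperparameters as above: with staircase=True the rate is lr * decay ** floor(global_step / step_rate).

```python
# Sketch of tf.train.exponential_decay(..., staircase=True) as plain Python.
# lr, step_rate, and decay match the snippet above.
lr = 0.1
step_rate = 1000
decay = 0.95

def decayed_lr(global_step):
    # Integer division gives the "staircase": the rate drops once per
    # step_rate steps instead of decaying continuously.
    return lr * decay ** (global_step // step_rate)

for step in (0, 999, 1000, 2500):
    print(step, decayed_lr(step))
```

So the learning rate stays at 0.1 for the first 1000 steps, then drops to 0.095, and so on; this is the sequence of values you would see from sess.run on the decayed rate.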
answered Oct 05 '22 by X. Serra