
Reset tensorflow Optimizer

I am loading a saved model and would like to reset a TensorFlow optimizer, such as an Adam optimizer. Ideally something like:

sess.run([tf.initialize_variables(Adamopt)])

or

sess.run([Adamopt.reset])

I have looked for an answer but have yet to find a way to do it. Here is what I found, none of which addresses the issue: https://github.com/tensorflow/tensorflow/issues/634

In TensorFlow is there any way to just initialize uninitialised variables?

Tensorflow: Using Adam optimizer

I basically just want a way to reset the "slot" variables in the Adam Optimizer.

Thanks

Steven asked Sep 21 '16 04:09


3 Answers

In TensorFlow 2.x you can reset an optimizer such as Adam by zeroing its variables:

for var in optimizer.variables():
    var.assign(tf.zeros_like(var))
EdisonLeejt answered Nov 13 '22 02:11


This question bothered me for quite a while too. It is actually quite easy: define an operation that re-initializes the optimizer's current state, which you can obtain with the variables() method. Something like this:

optimizer = tf.train.AdamOptimizer(0.1, name='Optimizer')
reset_optimizer_op = tf.variables_initializer(optimizer.variables())

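Create reset_optimizer_op after calling optimizer.minimize(...), because the slot variables do not exist before that. Then, whenever you need to reset the optimizer, run: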

sess.run(reset_optimizer_op)

Official explanation of variables():

A list of variables which encode the current state of Optimizer. Includes slot variables and additional global variables created by the optimizer in the current default graph.

E.g. for AdamOptimizer you basically get the first and second moments (with slot names 'm' and 'v') of all trainable variables, as well as beta1_power and beta2_power.
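To make that state concrete, here is a minimal pure-Python sketch of Adam with an explicit reset. It is an illustration only, not TensorFlow's implementation: the class name TinyAdam and its methods are hypothetical, and the bookkeeping for beta1_power and beta2_power (start at 1.0, multiply before each step) is this sketch's own convention. It shows that a reset means restoring the moments and the beta powers to their initial values, after which the next update matches the first update of a fresh optimizer.

```python
import math


class TinyAdam:
    """Toy Adam over a list of floats (hypothetical, for illustration only)."""

    def __init__(self, n_params, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.n_params = n_params
        self.reset()

    def reset(self):
        # Restore the state a freshly constructed optimizer would have.
        self.m = [0.0] * self.n_params  # first-moment slots
        self.v = [0.0] * self.n_params  # second-moment slots
        self.beta1_power = 1.0          # running product beta1**t (sketch convention)
        self.beta2_power = 1.0          # running product beta2**t (sketch convention)

    def step(self, params, grads):
        # Advance the beta powers, then apply the bias-corrected update.
        self.beta1_power *= self.beta1
        self.beta2_power *= self.beta2
        updated = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1_power)
            v_hat = self.v[i] / (1 - self.beta2_power)
            updated.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return updated


opt = TinyAdam(n_params=1)
w = [1.0]
first = opt.step(w, [2.0])   # update from a fresh state
opt.step(first, [2.0])       # accumulate some history
opt.reset()                  # drop the history
again = opt.step(w, [2.0])   # behaves like a fresh optimizer again
```

Note that in this sketch the beta powers reset to their initial value rather than to zero, which is why re-running an initializer and zeroing every variable are not necessarily the same thing.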

Yuchi Yang answered Nov 13 '22 02:11


The simplest way I found was to give the optimizer its own variable scope and then run

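# Collect every global variable whose name starts with the optimizer's scope prefix.
optimizer_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   "scope/prefix/for/optimizer")
# tf.variables_initializer replaces the deprecated tf.initialize_variables.
sess.run(tf.variables_initializer(optimizer_vars))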

idea from freeze weights
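As far as I understand, the scope argument of tf.get_collection is applied with re.match against each variable's name, so a plain string acts as a prefix filter. The sketch below imitates that filtering in pure Python without TensorFlow. The helper filter_by_scope and all variable names are hypothetical, chosen only to show which names a prefix such as "opt/" would select.

```python
import re


def filter_by_scope(names, scope):
    # Keep names that match the scope at the start of the string,
    # which is how a plain prefix ends up selecting a whole scope.
    return [name for name in names if re.match(scope, name)]


# Hypothetical variable names: model weights plus optimizer state under "opt/".
names = [
    "model/dense/kernel",
    "model/dense/bias",
    "opt/model/dense/kernel/Adam",
    "opt/model/dense/kernel/Adam_1",
    "opt/beta1_power",
]

optimizer_names = filter_by_scope(names, "opt/")
```

Only the names under the prefix are selected, so re-initializing them leaves the model weights untouched.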

Steven answered Nov 13 '22 01:11