
tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) in tensorflow

What is the purpose of tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) in TensorFlow?

With more context:

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_fn, var_list=tf.trainable_variables())
mrgloom asked Dec 14 '18


2 Answers

The tf.control_dependencies method ensures that the operations passed as inputs to the context manager are run before any operation defined inside the context manager.

For example:

import tensorflow as tf

# A non-trainable counter variable and an op that increments it.
count = tf.get_variable("count", shape=(), initializer=tf.constant_initializer(1), trainable=False)
count_increment = tf.assign_add(count, 1)

c = tf.constant(2.)
# count_increment is forced to run before any op defined inside this block.
with tf.control_dependencies([count_increment]):
    d = c + 3

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("eval count", count.eval())
    print("eval d", d.eval())
    print("eval count", count.eval())

This prints:

eval count 1
eval d 5.0 # running d makes the count_increment op run
eval count 2 # count_increment has been run, so count now holds 2

So in your case: each time you run the train_op operation, it will first run all the operations collected in tf.GraphKeys.UPDATE_OPS.
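To see the same mechanism at work with the UPDATE_OPS collection itself, here is a minimal sketch (TF 1.x; the step_counter variable and the toy loss are made up for illustration, not taken from your code): any op you add to tf.GraphKeys.UPDATE_OPS is forced to run each time train_op runs.

import tensorflow as tf

# Toy variable and loss, just to have something to minimize.
x = tf.get_variable("x", shape=(), initializer=tf.constant_initializer(5.0))
loss = tf.square(x)

# Any op placed in the UPDATE_OPS collection by hand...
step_counter = tf.get_variable("step_counter", shape=(), trainable=False,
                               initializer=tf.zeros_initializer())
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, tf.assign_add(step_counter, 1.0))

# ...is forced to run before the minimize op created inside this block.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(train_op)
    print(sess.run(step_counter))  # 3.0 (the update op ran once per training step)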

pfm answered Nov 11 '22


If you use, for example, tf.layers.batch_normalization, the layer will create some ops that need to be run at every training step (they update the moving mean and variance of the layer).

tf.GraphKeys.UPDATE_OPS is the collection that holds these ops, and if you put it in a tf.control_dependencies block, these ops will be executed before the training op is run.

https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
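For completeness, here is a minimal sketch of the usual training-graph pattern this implies (TF 1.x; the placeholder shapes, layer sizes and learning rate are arbitrary, chosen only for illustration):

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 32])
labels = tf.placeholder(tf.int64, [None])
is_training = tf.placeholder(tf.bool, [])

# batch_normalization registers its moving-mean/variance update ops
# in tf.GraphKeys.UPDATE_OPS instead of wiring them into the forward graph.
net = tf.layers.dense(inputs, 64)
net = tf.layers.batch_normalization(net, training=is_training)
logits = tf.layers.dense(tf.nn.relu(net), 10)

loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# Without this block the moving statistics would never be updated,
# and inference (training=False) would use stale values.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)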

Jindra Helcl answered Nov 11 '22