 

Order of sess.run([op1, op2...]) in Tensorflow

I wonder what the running order of the ops in sess.run(ops_list, ...) is. For example, in a typical classification scenario: _, loss = sess.run([train_op, loss_op]). If train_op runs first, then loss is the loss after the current backprop. But if loss_op runs first, then loss is the loss before the current backprop. Can anyone help me? Thanks.

Long Liangqu asked Nov 06 '18 03:11


1 Answer

loss will be the value of the loss before the update caused by train_op. Note that loss_op is an input to the optimizer, so it necessarily goes "before" train_op in the graph: loss_op is computed with the variable values as they are at the beginning of the run call. If you wanted the loss after train_op, you could get it, for example, by computing the loss again inside a tf.control_dependencies block on the optimizer op, but then you would be doing two forward passes of your model on each step, with the associated cost. Usually, if you just want to plot the loss for monitoring or something similar, it is fine to use this pre-update value, i.e. the loss for the weights from the previous step.
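To make the timing concrete, here is a minimal plain-Python model of a training step. It is not TensorFlow code: the one-parameter model, its squared loss, the learning rate and the function names (forward_loss, run_step, run_step_post_update_loss) are all hypothetical. It shows that the loss fetched in a single call is evaluated on the weights as they were when the call started, and that getting the post-update loss costs a second forward pass.

```python
# Plain-Python model (NOT TensorFlow) of `_, loss = sess.run([train_op, loss_op])`
# for a hypothetical one-parameter model with loss(w) = (w - TARGET) ** 2,
# trained by plain gradient descent.

TARGET = 3.0         # hypothetical regression target
LEARNING_RATE = 0.1  # hypothetical step size


def forward_loss(w):
    """Forward pass: the squared error for parameter w."""
    return (w - TARGET) ** 2


def gradient(w):
    """Derivative of forward_loss with respect to w."""
    return 2.0 * (w - TARGET)


def run_step(w):
    """One run call: loss_op and the gradient both see the pre-update w,
    then train_op writes the updated w. Returns (new_w, fetched_loss)."""
    loss = forward_loss(w)                   # loss_op: uses pre-update w
    new_w = w - LEARNING_RATE * gradient(w)  # train_op: applies the update
    return new_w, loss


def run_step_post_update_loss(w):
    """Variant that recomputes the loss after the update, analogous to a
    second loss computation placed under tf.control_dependencies([train_op]).
    It needs an extra forward pass on every step."""
    new_w, _ = run_step(w)
    return new_w, forward_loss(new_w)        # second forward pass


w = 0.0
for step in range(3):
    w, loss = run_step(w)
    # The printed loss belongs to the weights *before* this step's update
    print(f"step {step}: loss {loss:.4f}, w after update {w:.4f}")
```

Under these assumptions, each reported loss lags the weights by one update, which is usually harmless for monitoring.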

For further explanation: in general, the order in which TensorFlow operations are executed is only guaranteed to the extent that these operations depend on each other, and it is not related to the order in which they are passed to run. In your case, train_op depends on loss_op, so loss_op has to run first. In other cases, however, operations do not depend on each other, and then the order is not guaranteed. Often that does not matter, but sometimes it does. Consider the following example:

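import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

v = tf.Variable(0)
v2 = 2 * v                  # reads v; no dependency on v_update
v_update = v.assign(v + 1)  # also reads v; no dependency on v2
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(5):
        # The relative order of v_update and v2 within one run is not guaranteed
        print(sess.run([v_update, v2]))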

A run on my computer produced this output:

[1, 0]
[2, 2]
[3, 4]
[4, 8]
[5, 10]

As you can see, v2 is sometimes twice the updated value and sometimes twice the non-updated value. If, for example, we wanted to make sure that v2 always runs before v_update, we could do:

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

v = tf.Variable(0)
v2 = 2 * v
# Ops created inside this block cannot run until v2 has been computed
with tf.control_dependencies([v2]):
    v_update = v.assign(v + 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(5):
        print(sess.run([v_update, v2]))

This consistently produces:

[1, 0]
[2, 2]
[3, 4]
[4, 6]
[5, 8]
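Both outputs can be reproduced with a small plain-Python model. This is not TensorFlow's scheduler: it simply treats any order consistent with the dependency edges as valid and lets the caller pick which valid order each step uses, standing in for the choice TensorFlow leaves unspecified. The names valid_orders, run_once and simulate are hypothetical. Running v2 first in the first three steps and v_update first in the last two reproduces the first output above, while the control dependency leaves only one valid order, which gives the second output.

```python
from itertools import permutations

# Dependency edges among the two fetched ops (op -> ops it must wait for).
# Both ops only read v, so without a control dependency there is no edge.
UNCONSTRAINED = {"v_update": set(), "v2": set()}
# tf.control_dependencies([v2]) adds an edge: v_update waits for v2.
CONTROLLED = {"v_update": {"v2"}, "v2": set()}


def valid_orders(graph):
    """Return every execution order that respects the dependency edges."""
    orders = []
    for order in permutations(graph):
        pos = {name: i for i, name in enumerate(order)}
        if all(pos[dep] < pos[op] for op in graph for dep in graph[op]):
            orders.append(order)
    return orders


def run_once(v, order):
    """Mimic one sess.run([v_update, v2]) with the ops executed in `order`.

    Returns the new value of v and the fetched [v_update, v2] results.
    """
    results = {}
    for op in order:
        if op == "v_update":
            v = v + 1
            results["v_update"] = v
        else:  # "v2" reads whatever v holds at that moment
            results["v2"] = 2 * v
    return v, [results["v_update"], results["v2"]]


def simulate(orders):
    """Run one step per entry in `orders`, starting from v = 0."""
    v, outputs = 0, []
    for order in orders:
        v, out = run_once(v, order)
        outputs.append(out)
    return outputs


print(valid_orders(UNCONSTRAINED))  # both orders are allowed
print(valid_orders(CONTROLLED))     # only v2 before v_update

v2_first = ("v2", "v_update")
update_first = ("v_update", "v2")
print(simulate([v2_first] * 5))
# [[1, 0], [2, 2], [3, 4], [4, 6], [5, 8]]
print(simulate([v2_first] * 3 + [update_first] * 2))
# [[1, 0], [2, 2], [3, 4], [4, 8], [5, 10]]
```

The fetch list [v_update, v2] is the same in every case; only the dependency edges constrain which orders can occur.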
jdehesa answered Sep 25 '22 12:09