I am trying to update a tf.Variable inside a tf.while_loop(), using tf.scatter_update(). However, the result is the initial value instead of the updated value. Here is sample code showing what I am trying to do:
from __future__ import print_function
import tensorflow as tf

def cond(sequence_len, step):
    return tf.less(step, sequence_len)

def body(sequence_len, step):
    begin = tf.get_variable("begin", [3], dtype=tf.int32, initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(begin, 1, step, use_locking=None)
    tf.get_variable_scope().reuse_variables()
    return (sequence_len, step + 1)

with tf.Graph().as_default():
    sess = tf.Session()
    step = tf.constant(0)
    sequence_len = tf.constant(10)
    _, step = tf.while_loop(cond,
                            body,
                            [sequence_len, step],
                            parallel_iterations=10,
                            back_prop=True,
                            swap_memory=False,
                            name=None)
    begin = tf.get_variable("begin", [3], dtype=tf.int32)
    init = tf.initialize_all_variables()
    sess.run(init)
    print(sess.run([begin, step]))
The result is: [array([0, 0, 0], dtype=int32), 10]. However, I think the result should be [0, 0, 10]. Am I doing something wrong here?
The problem here is that nothing in the loop body depends on your tf.scatter_update() op, so it is never executed. The easiest way to make it work is to add a control dependency on the update to the return values:
def body(sequence_len, step):
    begin = tf.get_variable("begin", [3], dtype=tf.int32, initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(begin, 1, step, use_locking=None)
    tf.get_variable_scope().reuse_variables()
    # Ops created inside this block (here, step + 1) can only run after the update.
    with tf.control_dependencies([begin]):
        return (sequence_len, step + 1)
Note that this problem isn't unique to loops in TensorFlow. If you had just defined a tf.scatter_update() op called begin but never called sess.run() on it (or on something that depends on it), the update would not happen either. When you're using tf.while_loop(), there's no way to run the operations defined in the loop body directly, so the easiest way to get a side effect is to add a control dependency.
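To make the pruning behaviour concrete, here is a toy model in plain Python. It is a sketch under stated assumptions, not TensorFlow: `Node`, `run` and `build` are hypothetical names, and `run` simply evaluates only what the fetched node transitively depends on (data inputs plus control inputs), which is the behaviour described above. A Python list stands in for the variable begin, and the graph models a single loop iteration with step == 9.

```python
class Node:
    """Hypothetical toy op: NOT the TensorFlow API, just a model of lazy graph execution."""

    def __init__(self, fn, inputs=(), control_inputs=()):
        self.fn = fn                                # computes this node's value from its inputs
        self.inputs = list(inputs)                  # data dependencies
        self.control_inputs = list(control_inputs)  # must run first; their values are ignored


def run(fetch):
    """Evaluate `fetch`, executing only the nodes it transitively depends on."""
    cache = {}

    def evaluate(node):
        if node not in cache:
            for dep in node.control_inputs:
                evaluate(dep)  # executed only for its side effect
            args = [evaluate(i) for i in node.inputs]
            cache[node] = node.fn(*args)
        return cache[node]

    return evaluate(fetch)


def build(begin, use_control_dependency):
    """Toy model of one loop-body iteration with step == 9."""
    step = Node(lambda: 9)
    # Side-effecting "scatter update": writes step into begin[1].
    update = Node(lambda s: begin.__setitem__(1, s), inputs=[step])
    controls = [update] if use_control_dependency else []
    # The value the body returns; only depends on update if we add the control input.
    next_step = Node(lambda s: s + 1, inputs=[step], control_inputs=controls)
    return next_step


without_dep = [0, 0, 0]
print(run(build(without_dep, False)), without_dep)  # 10 [0, 0, 0]

with_dep = [0, 0, 0]
print(run(build(with_dep, True)), with_dep)  # 10 [0, 9, 0]
```

Without the control dependency, fetching next_step never reaches the update node, so the list stays [0, 0, 0]; adding the update as a control input forces the write to happen first, which is the role tf.control_dependencies plays for step + 1 in the fixed loop body.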
Note that the final result is [0, 9, 0]: each iteration assigns the current step to begin[1], and in the last iteration the value of the current step is 9 (the condition is false when step == 10).
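The same arithmetic can be checked with an eager, plain-Python trace of the loop. This is only a sketch that mirrors cond and body; the list and integers are stand-ins for the TensorFlow variable and tensors, not TensorFlow code.

```python
begin = [0, 0, 0]            # stand-in for the variable "begin"
step, sequence_len = 0, 10

while step < sequence_len:   # cond: tf.less(step, sequence_len)
    begin[1] = step          # body: tf.scatter_update(begin, 1, step)
    step += 1                # body returns step + 1

print(begin, step)  # [0, 9, 0] 10
```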