I am trying to parallelize a loop using tf.while_loop. As suggested here, the parallel_iterations argument doesn't make a difference in eager mode, so I attempted to wrap tf.while_loop with tf.function. However, after adding the decorator, the behavior of the iteration variable changes.
For example, this piece of code works.
import numpy as np
import tensorflow as tf

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)

def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)

tf.while_loop(c, print_fun, [iteration])
If I add the decorator, an error occurs:
import numpy as np
import tensorflow as tf

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)

def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    tf.while_loop(c, print_fun, [iteration])

run_graph()
While debugging, I found that the variable iteration changes from a tensor to a placeholder. Why is that? How should I modify the code to eliminate the bug?
Thanks.
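The code in your first snippet (the one without @tf.function) relies on TensorFlow 2's eager execution to manipulate a NumPy array (i.e., your outer result object) directly: in eager mode, iteration holds a concrete value, so NumPy can use it as an index. With @tf.function, this doesn't work because @tf.function traces your code into a tf.Graph, and tf.while_loop traces its body with symbolic tensors (the placeholders you saw) standing in for the loop variables. A symbolic tensor has no concrete value at trace time, so it cannot be used to index or write into a NumPy array; the graph can only operate on TensorFlow tensors and variables. To get around this issue, use a tf.Variable and assign values into its slices.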
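A minimal sketch of that fix that keeps your explicit tf.while_loop (and therefore the parallel_iterations argument), assuming TensorFlow 2.x. The names cond and body are illustrative stand-ins for your c and print_fun. Whether parallel_iterations actually speeds anything up depends on what the body does; here it only shows where the argument goes once the loop runs as a graph.

```python
import numpy as np
import tensorflow as tf

# The Variable is created eagerly, outside the traced function,
# so the graph can write into it (unlike a NumPy array).
result = tf.Variable(np.zeros(10, dtype=np.int32))

def cond(i):
    return tf.less(i, 10)

def body(i):
    # While tracing, i is a symbolic tensor; slice-assign into the
    # Variable instead of indexing a NumPy array with it.
    result[i].assign(i)
    return [i + 1]  # same structure as loop_vars

@tf.function
def run_graph():
    i = tf.constant(0, dtype=tf.int32)
    # parallel_iterations is only meaningful in graph mode (inside tf.function).
    return tf.while_loop(cond, body, [i], parallel_iterations=10)

run_graph()
print(result.numpy())  # [0 1 2 3 4 5 6 7 8 9]
```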
With @tf.function, what you are trying to do is actually achievable with simpler code by taking advantage of @tf.function's automatic Python-to-graph transformation (known as AutoGraph). You just write a normal Python while loop whose condition is a tensor (tf.less() works; since i is a tensor, the < operator builds the same op), and AutoGraph compiles the loop into a tf.while_loop under the hood.
The code looks something like:
import numpy as np
import tensorflow as tf

result = tf.Variable(np.zeros([10], dtype=np.int32))

@tf.function
def run_graph():
    i = tf.constant(0, dtype=tf.int32)
    while tf.less(i, 10):
        result[i].assign(i)  # Performance may require tuning here.
        i += 1

run_graph()
print(result.read_value())
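If you would rather get the values back as the function's return value than write them into a Variable, a tf.TensorArray is one option. This is a different technique from the tf.Variable approach described in the answer, sketched here assuming TensorFlow 2.x; run_graph_tensorarray is a hypothetical name. Because i is a tensor, the plain < operator also produces a tensor condition, so AutoGraph still turns the Python while loop into a tf.while_loop.

```python
import tensorflow as tf

@tf.function
def run_graph_tensorarray():  # hypothetical helper name
    # Collect one int32 value per iteration in a TensorArray loop variable.
    out = tf.TensorArray(tf.int32, size=10)
    i = tf.constant(0, dtype=tf.int32)
    # i is a tensor, so `i < 10` is a tensor and AutoGraph
    # converts this loop into a tf.while_loop.
    while i < 10:
        out = out.write(i, i)  # write returns the updated TensorArray
        i += 1
    return out.stack()  # shape (10,) int32 tensor

print(run_graph_tensorarray())
```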