I am trying to produce a very simple example that combines TensorArray and while_loop:
# 1000 sequences, each of length 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(tf.float32, size=matrix_rows)
ta = ta.unstack(matrix)
init_state = (0, ta)
condition = lambda i, _: i < n
body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2))
# run the graph
with tf.Session() as sess:
    (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))})
    print (ta_final.stack())
But I am getting the following error:
ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32).
Does anyone have an idea what the problem is?
There are several things to point out in your code. First, you don't need to unstack the matrix into the TensorArray to use it inside the loop; you can safely reference the matrix Tensor inside the body and index it with matrix[i]. Second, the data types differ: your matrix is tf.int32 but the TensorArray is tf.float32. Since your code multiplies the matrix ints by 2 and writes the result into the array, the array should be int32 as well. Finally, when you wish to read the final result of the loop, the correct operation is TensorArray.stack(), and its output is what you need to run in your session.run call.
Here's a working example:
import numpy as np
import tensorflow as tf
# 1000 sequences, each of length 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows)
init_state = (0, ta)
condition = lambda i, _: i < matrix_rows
body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2))
n, ta_final = tf.while_loop(condition, body, init_state)
# get the final result
ta_final_result = ta_final.stack()
# run the graph
with tf.Session() as sess:
    # print the output of ta_final_result
    print(sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100,1000), dtype=np.int32)}))
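If you would rather keep the unstack/read pattern from the question, here is a minimal sketch of how that could look. It rests on a few assumptions that are not in the original post: it imports tensorflow.compat.v1 and calls disable_v2_behavior() so the same graph/session style also runs on TensorFlow 2.x (this needs a reasonably recent TensorFlow), it uses a small 4x3 input instead of 100x1000 to keep the output readable, and the helper name double_rows is made up for illustration. TensorArray is documented as a write-once array, so the sketch reads from one TensorArray and writes into a second one instead of writing back to the index it just read.

```python
import numpy as np
# Assumption: a TensorFlow version that ships the compat.v1 module.
import tensorflow.compat.v1 as tf

# Keep the graph/session style used in this thread, even on TF 2.x.
tf.disable_v2_behavior()


def double_rows(data):
    """Hypothetical helper: double each row of a 2-D int32 array in a while_loop."""
    graph = tf.Graph()
    with graph.as_default():
        matrix = tf.placeholder(tf.int32, shape=data.shape, name="input_matrix")
        matrix_rows = tf.shape(matrix)[0]

        # Source array: one element per row, same dtype as the placeholder.
        input_ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows).unstack(matrix)
        # Separate destination array, so no index is written twice.
        output_ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows)

        condition = lambda i, _: i < matrix_rows
        body = lambda i, out: (i + 1, out.write(i, input_ta.read(i) * 2))

        _, final_ta = tf.while_loop(condition, body, (0, output_ta))
        # stack() turns the TensorArray back into a regular Tensor we can fetch.
        result = final_ta.stack()

        with tf.Session(graph=graph) as sess:
            # Feed a NumPy array, not a tf tensor.
            return sess.run(result, feed_dict={matrix: data})


data = np.arange(12, dtype=np.int32).reshape(4, 3)
doubled = double_rows(data)
print(doubled)
```

Indexing the matrix directly, as in the working example above, is shorter; the two-array version is mainly useful when the input already lives in a TensorArray.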