
How do TensorArray and while_loop work together in TensorFlow?

I am trying to write a very simple example that combines TensorArray and while_loop:

# 1000 sequence in the length of 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(tf.float32, size=matrix_rows)
ta = ta.unstack(matrix)

init_state = (0, ta)
condition = lambda i, _: i < n
body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2))

# run the graph
with tf.Session() as sess:
    (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))})
    print (ta_final.stack())

But I am getting the following error:

ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32).

Does anyone have an idea what the problem is?

asked May 08 '17 14:05 by E.Asgari


1 Answer

There are several things to point out in your code. First, you don't need to unstack the matrix into the TensorArray to use it inside the loop; you can safely reference the matrix Tensor inside the body and index it with matrix[i]. Second, the data types of your matrix (tf.int32) and your TensorArray (tf.float32) differ. Your code multiplies the matrix ints by 2 and writes the result into the array, so the array should be int32 as well. Finally, to read the final result of the loop, the correct operation is TensorArray.stack(), and its output is what you need to run in your session.run call.

Here's a working example:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

# 1000 sequence in the length of 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows)

init_state = (0, ta)
condition = lambda i, _: i < matrix_rows
body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2))
n, ta_final = tf.while_loop(condition, body, init_state)
# get the final result
ta_final_result = ta_final.stack()

# run the graph
with tf.Session() as sess:
    # print the output of ta_final_result
    print(sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100, 1000), dtype=np.int32)}))
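
To make the data flow of the loop above easier to follow, here is a rough plain-Python model of it that runs without TensorFlow. ToyTensorArray and toy_while_loop are hypothetical stand-ins invented for this sketch, not TensorFlow APIs, and a small 3x4 list replaces the 100x1000 placeholder. The point it tries to illustrate is that write() returns a new array object, so the body has to hand that object back as a loop variable, and that stack() is only called once, on the array the loop returns.

```python
# Plain-Python model of the tf.while_loop + TensorArray pattern.
# ToyTensorArray and toy_while_loop are hypothetical names for illustration only.


class ToyTensorArray:
    """Write-returns-new-object array, loosely mimicking TensorArray.write()."""

    def __init__(self, size, items=None):
        self._items = list(items) if items is not None else [None] * size

    def write(self, index, value):
        # Copy the slots, fill one of them, and return a NEW array object.
        new_items = list(self._items)
        new_items[index] = value
        return ToyTensorArray(len(new_items), new_items)

    def stack(self):
        # Collect all written rows into one nested list.
        return [list(row) for row in self._items]


def toy_while_loop(cond, body, loop_vars):
    """Apply body to the loop variables for as long as cond holds."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


# Small stand-in for the (100, 1000) int32 placeholder.
matrix = [[1] * 4 for _ in range(3)]
matrix_rows = len(matrix)

init_state = (0, ToyTensorArray(matrix_rows))
condition = lambda i, _: i < matrix_rows
# Index the matrix directly and thread the array returned by write() onward.
body = lambda i, ta: (i + 1, ta.write(i, [x * 2 for x in matrix[i]]))

n, ta_final = toy_while_loop(condition, body, init_state)
result = ta_final.stack()
print(n, result)
```

In the real graph the same shape applies: the counter and the TensorArray travel together through init_state, condition and body, and only the tensor produced by stack() is passed to sess.run.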
answered Nov 15 '22 07:11 by sirfz