TensorFlow while loop: dealing with lists

import tensorflow as tf

array = tf.Variable(tf.random_normal([10]))
i = tf.constant(0)
l = []

def cond(i, l):
    return i < 10

def body(i, l):
    temp = tf.gather(array, i)
    l.append(temp)
    return i + 1, l

index,list_vals = tf.while_loop(cond, body, [i,l])

I want to process a tensor array in a similar way to the code above: inside the body of the while loop, I want to go through the array element by element and apply some function to each element. The snippet above is a small demonstration, but it gives the following error message.

ValueError: Number of inputs and outputs of body must match loop_vars: 1, 2

Any help in resolving this is appreciated.

Thanks

asked Dec 20 '16 by webNash

2 Answers

Citing the documentation:

loop_vars is a (possibly nested) tuple, namedtuple or list of tensors that is passed to both cond and body

You cannot pass a regular Python list as one of the loop variables. What you can do instead is:

i = tf.constant(0)
l = tf.Variable([])   # empty float32 tensor to accumulate results into

def body(i, l):
    # `array` and `cond` are the same as in the question
    temp = tf.gather(array, i)
    l = tf.concat([l, [temp]], 0)   # append the gathered element
    return i + 1, l

index, list_vals = tf.while_loop(cond, body, [i, l],
                                 shape_invariants=[i.get_shape(),
                                                   tf.TensorShape([None])])

The shape invariants are needed because tf.while_loop normally expects the shapes of the loop variables not to change between iterations; since l grows by one element on every iteration, its shape must be declared as tf.TensorShape([None]).

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(list_vals)
Out: array([-0.38367489, -1.76104736,  0.26266089, -2.74720812,  1.48196387,
            -0.23357525, -1.07429159, -1.79547787, -0.74316853,  0.15982138], 
           dtype=float32)
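
For readers on TensorFlow 2.x (an assumption; the answer above targets the 1.x graph API), the same growing-tensor pattern can be sketched roughly like this, with tf.function standing in for the explicit Session:

import tensorflow as tf  # assumes TensorFlow 2.x

array = tf.random.normal([10])

@tf.function
def gather_all(array):
    # Same idea as above: grow an initially empty float32 tensor, declaring
    # its shape invariant as [None] so it may change between iterations.
    i0 = tf.constant(0)
    l0 = tf.zeros([0], dtype=tf.float32)
    _, l = tf.while_loop(
        lambda i, l: i < 10,
        lambda i, l: (i + 1, tf.concat([l, [tf.gather(array, i)]], 0)),
        [i0, l0],
        shape_invariants=[i0.get_shape(), tf.TensorShape([None])])
    return l

print(gather_all(array).numpy())  # ten random values, as in the output above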
answered Sep 21 '22 by sygi

TensorFlow offers tf.TensorArray to deal with such cases. From the documentation:

Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

This class is meant to be used with dynamic iteration primitives such as while_loop and map_fn. It supports gradient back-propagation via special "flow" control flow dependencies.

Here is an example:

import tensorflow as tf

array = tf.Variable(tf.random_normal([10]))
step = tf.constant(0)
# dynamic_size=True lets the TensorArray grow as elements are written
output = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

def cond(step, output):
    return step < 10

def body(step, output):
    # write the step-th element of `array` into slot `step`
    output = output.write(step, tf.gather(array, step))
    return step + 1, output

_, final_output = tf.while_loop(cond, body, loop_vars=[step, output])

# stack() packs all written elements back into a single tensor
final_output = final_output.stack()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(final_output))
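
If the per-element processing does not need an explicit loop counter, tf.map_fn (the other dynamic-iteration primitive mentioned in the quoted docs) can express the same thing more compactly. A minimal sketch, where process_element is a hypothetical stand-in for whatever per-element function you want to apply:

import tensorflow as tf

array = tf.Variable(tf.random_normal([10]))

def process_element(x):
    return x * 2.0  # hypothetical per-element computation

# map_fn applies process_element to each element of `array` along axis 0
result = tf.map_fn(process_element, array)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))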
answered Sep 21 '22 by sudoer