
TensorFlow asks for inputs to unnecessary placeholders when using tf.cond()

Consider the following code snippet, which uses TensorFlow's tf.cond().

import tensorflow as tf
import numpy as np

bb = tf.placeholder(tf.bool)
xx = tf.placeholder(tf.float32, name='xx')
yy = tf.placeholder(tf.float32, name='yy')

zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

with tf.Session() as sess:
    dict1 = {bb: False, yy: np.array([1., 3, 4]), xx: np.array([5., 6, 7])}
    print(sess.run(zz, feed_dict=dict1))  # works fine without errors

    dict2 = {bb: False, yy: np.array([1., 3, 4])}
    print(sess.run(zz, feed_dict=dict2))  # raises an InvalidArgumentError asking
                                          # for an input for xx

In both cases, bb is False, so evaluating zz should in principle not depend on xx, yet TensorFlow still requires an input for xx. A dummy array can be fed instead, but it has to match the shape of yy, and that is not as clean as dict2.

Can anybody suggest how to evaluate zz (using tf.cond() or any other approach) without providing a value for xx?

Batta asked Jan 28 '23 16:01


1 Answer

You can define xx as a tf.Variable instead, giving it a default value (which will be used whenever xx is not fed with another value). A few things to notice:

  1. Although xx is not a placeholder, you can still treat it as one by feeding values into it through the feed_dict.
  2. Use validate_shape=False so that you can feed values of any shape into xx.
  3. Use trainable=False so that xx is not optimized (otherwise, an optimizer might change its default value to something like NaN, which may cause problems).
  4. Don't forget to initialize xx before running the graph, e.g. with tf.global_variables_initializer().

Here is the code:

import tensorflow as tf
import numpy as np

bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0, validate_shape=False, trainable=False, name='xx')
yy = tf.placeholder(tf.float32, name='yy')

zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

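with tf.Session() as sess:
    # Give xx its default value (0.0) before anything is evaluated.
    sess.run(tf.global_variables_initializer())

    # xx is fed explicitly, overriding its default value.
    dict1 = {bb: False, yy: np.array([1., 3, 4]), xx: np.array([5., 6, 7])}
    print(sess.run(zz, feed_dict=dict1))

    # xx is not fed, so its initial value is used and no error is raised.
    dict2 = {bb: False, yy: np.array([1., 3, 4])}
    print(sess.run(zz, feed_dict=dict2))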
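
Another option, not part of the answer above, may be tf.placeholder_with_default, which behaves like a placeholder but falls back to a default value when nothing is fed, so no variable initialization is needed. The following is only a minimal sketch. It assumes a TensorFlow installation that ships the tensorflow.compat.v1 module (1.14+ or 2.x), and the evaluate helper is a hypothetical name introduced here for illustration.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x with compat.v1

tf.disable_v2_behavior()  # keep graph-mode placeholders and sessions

bb = tf.placeholder(tf.bool, name='bb')
# Falls back to the scalar 0.0 when xx is not fed; shape=None leaves the
# shape unconstrained, so arrays of any shape can still be fed.
xx = tf.placeholder_with_default(
    tf.constant(0.0, dtype=tf.float32), shape=None, name='xx')
yy = tf.placeholder(tf.float32, name='yy')

zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)


def evaluate(bb_value, yy_value, xx_value=None):
    """Hypothetical helper: run zz, feeding xx only when a value is given."""
    feed = {bb: bb_value, yy: np.asarray(yy_value, dtype=np.float32)}
    if xx_value is not None:
        feed[xx] = np.asarray(xx_value, dtype=np.float32)
    with tf.Session() as sess:
        return sess.run(zz, feed_dict=feed)


print(evaluate(False, [1., 3., 4.]))               # xx omitted, default used
print(evaluate(True, [1., 3., 4.], [5., 6., 7.]))  # xx fed explicitly
```

As with the variable approach, the default is still evaluated even when bb is False. The sketch only removes the need to feed xx by hand.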

Lior answered Jan 31 '23 07:01