 

tensorflow: check if a scalar boolean tensor is True

I want to control the execution of a function using a placeholder, but I keep getting the error "Using a tf.Tensor as a Python bool is not allowed". Here is the code that produces this error:

import tensorflow as tf
def foo(c):
  if c:  # raises TypeError: c is a symbolic tensor, not a Python bool
    print('This is true')
    #heavy code here
    return 10
  else:
    print('This is false')
    #different code here
    return 0

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()

I changed if c to if c is not None, without luck. How, then, can I control foo by switching the placeholder a on and off?
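To see why both attempts fail, here is a toy model in plain Python (no TensorFlow needed). SymbolicValue is a hypothetical class that only imitates a TF 1.x placeholder: it has no value while foo runs, and converting it to bool raises TypeError, much like the error above. It also shows why if c is not None does not help: that check runs once, at graph-construction time, and is always true because c is an object.

```python
class SymbolicValue:
    """Hypothetical stand-in for a TF 1.x placeholder: it has no value at build time."""

    def __bool__(self):
        # Mimics TensorFlow 1.x, where using a symbolic tensor as a bool raises TypeError.
        raise TypeError("Using a symbolic value as a Python bool is not allowed.")


def foo(c):
    # This `if` runs exactly once, while the graph is being built.
    if c is not None:  # always True: c is an object, never None
        return 10
    return 0


a = SymbolicValue()
print(foo(a))  # 10, whatever value would later be fed for `a`

try:
    if a:  # the equivalent of the original `if c:` line
        pass
except TypeError as err:
    print("TypeError:", err)
```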

Update: as @nessuno and @nemo point out, we must use tf.cond instead of if...else. The answer to my question is to redesign my function like this:

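import tensorflow as tf
def func1():
  return tf.constant(10)  # stand-in for the "true" branch (heavy code goes here)
def func2():
  return tf.constant(0)   # stand-in for the "false" branch
def foo(c):
  # tf.cond adds a conditional node to the graph; the branch is chosen at run time
  return tf.cond(c, func1, func2)

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()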
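One consequence of this redesign worth knowing: in TF 1.x graph mode, tf.cond calls both func1 and func2 once while it builds the graph, so a print() placed in either branch function runs at construction time, not when sess.run selects that branch. The toy below is plain Python; toy_cond, build_log and the closures are hypothetical stand-ins, not TensorFlow's API. It models that build-once, run-many behaviour:

```python
build_log = []  # records Python side effects that happen while the graph is built


def func1():
    # Hypothetical "heavy code" branch: this side effect happens at build time.
    build_log.append("func1 built")
    return lambda: 10  # the deferred "op" that is evaluated later


def func2():
    build_log.append("func2 built")
    return lambda: 0


def toy_cond(pred_key, true_fn, false_fn):
    # Toy model of graph-mode tf.cond: both branch builders are called once, up front.
    true_op, false_op = true_fn(), false_fn()

    def run(feed_dict):
        # At run time only the selected branch's op is evaluated.
        return true_op() if feed_dict[pred_key] else false_op()

    return run


b = toy_cond("a", func1, func2)
print(build_log)                        # ['func1 built', 'func2 built']
print(b({"a": True}), b({"a": False}))  # 10 0
print(build_log)                        # unchanged: running b does not rebuild
```

Relatedly, the tf.cond documentation notes that conditional execution applies only to ops defined inside the branch functions; an op created outside them that a branch depends on is executed unconditionally. So the heavy code belongs inside func1 and func2, not before the tf.cond call.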
asked Apr 06 '17 19:04 by Tu Bui


2 Answers

You have to use tf.cond to define a conditional operation within the graph and thus change the flow of the tensors.

import tensorflow as tf

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
# tf.cond picks a branch at run time; since a is already a bool tensor, tf.cond(a, ...) also works
b = tf.cond(tf.equal(a, tf.constant(True)), lambda: tf.constant(10), lambda: tf.constant(0))
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()
print(res)

10

answered Oct 31 '22 05:10 by nessuno


The actual execution does not happen in Python but in the TensorFlow backend, to which you supply the computation graph it should execute. This means that every condition and control-flow construct you want to apply must be expressed as a node in the computation graph.

For if conditions, there is the tf.cond operation:

import tensorflow as tf
c = tf.placeholder(tf.bool)  # scalar boolean predicate, fed at run time
b = tf.cond(c,
            lambda: tf.constant(10),
            lambda: tf.constant(0))
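The idea that flow control must live in the graph can be illustrated with a small toy graph in plain Python. Node, Placeholder, Constant, Cond and run are hypothetical names for this sketch, not TensorFlow classes: the graph is first built as ordinary objects, and a separate run step (playing the role of sess.run with a feed_dict) decides which branch to evaluate from the fed value.

```python
class Node:
    """Base class for a node in a hypothetical toy computation graph."""

    def evaluate(self, feed_dict):
        raise NotImplementedError


class Placeholder(Node):
    def evaluate(self, feed_dict):
        return feed_dict[self]  # value supplied only at run time


class Constant(Node):
    def __init__(self, value):
        self.value = value

    def evaluate(self, feed_dict):
        return self.value


class Cond(Node):
    def __init__(self, pred, true_fn, false_fn):
        self.pred = pred
        # Both branch subgraphs are built once, when the Cond node is created.
        self.true_node = true_fn()
        self.false_node = false_fn()

    def evaluate(self, feed_dict):
        # The "backend" picks the branch from the fed predicate value.
        chosen = self.true_node if self.pred.evaluate(feed_dict) else self.false_node
        return chosen.evaluate(feed_dict)


def run(node, feed_dict):
    """Toy equivalent of sess.run: evaluate a node given values for placeholders."""
    return node.evaluate(feed_dict)


c = Placeholder()
b = Cond(c, lambda: Constant(10), lambda: Constant(0))
print(run(b, {c: True}))   # 10
print(run(b, {c: False}))  # 0
```

Because the condition is a Cond node rather than a Python if, the same graph can be run repeatedly with different values fed for c.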
answered Oct 31 '22 05:10 by nemo