How do I write a piecewise TensorFlow function, i.e. a function that has an if-statement inside it?
Current code:
import tensorflow as tf

my_fn = lambda x: x ** 2 if x > 0 else x + 5

with tf.Session() as sess:
    x = tf.Variable(tf.random_normal([100, 1]))
    output = tf.map_fn(my_fn, x)
Error:
TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use the logical TensorFlow ops to test the value of a tensor.
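The error arises because a Python if (or conditional expression) needs a concrete bool, which a symbolic graph tensor cannot provide. If you want to keep the tf.map_fn structure from the question, one option not taken from the answers in this thread is tf.cond, which is TensorFlow's own graph-compatible conditional. The sketch below assumes TensorFlow 2.x (eager execution, so no tf.Session) and that each row passed in by tf.map_fn has shape [1]:

```python
import tensorflow as tf  # assumes TensorFlow 2.x


def my_fn(row: tf.Tensor) -> tf.Tensor:
    # tf.map_fn over a [100, 1] tensor passes one row of shape [1] at a time.
    # tf.cond needs a scalar boolean predicate, so index the single element.
    return tf.cond(row[0] > 0, lambda: row ** 2, lambda: row + 5)


x = tf.Variable(tf.random.normal([100, 1]))
output = tf.map_fn(my_fn, x)
print(output.shape)  # (100, 1)
```

This processes one row at a time, so it is typically slower than a single vectorized tf.where over the whole tensor.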
tf.select no longer works, as this thread also points out:
https://github.com/tensorflow/tensorflow/issues/8647
Something that worked for me was tf.where:
condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
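A minimal runnable version of the tf.where approach, assuming TensorFlow 2.x (eager execution). The helper name piecewise is hypothetical; the check at the end compares the result against the question's original lambda applied to one value at a time:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution, no tf.Session)


def piecewise(x: tf.Tensor) -> tf.Tensor:
    """Element-wise x**2 where x > 0, else x + 5 (hypothetical helper)."""
    condition = tf.greater(x, 0)
    # tf.where takes elements from the second argument where condition is True
    # and from the third argument elsewhere. Both branches are computed for
    # every element; only the selection is conditional.
    return tf.where(condition, tf.square(x), x + 5)


x = tf.random.normal([100, 1])
res = piecewise(x)

# Cross-check against the original Python lambda, applied to plain floats.
my_fn = lambda v: v ** 2 if v > 0 else v + 5
expected = np.array([[my_fn(v)] for v in x.numpy()[:, 0]], dtype=np.float32)
print(res.shape)  # (100, 1)
print(np.allclose(res.numpy(), expected))  # True
```

No tf.map_fn is needed here: tf.where works on the whole tensor at once.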
You should take a look at tf.where.
For your example, you could do:
condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
EDIT: moved from tf.select to tf.where.