
Writing piecewise functions in TensorFlow (if/then in TensorFlow)

How do I write a piecewise TensorFlow function, i.e. a function that has an if-statement inside it?

Current code:

import tensorflow as tf

# Python's if/else needs a bool, but x > 0 is a tf.Tensor, which raises the TypeError below
my_fn = lambda x : x ** 2 if x > 0 else x + 5
with tf.Session() as sess:
  x = tf.Variable(tf.random_normal([100, 1]))
  output = tf.map_fn(my_fn, x)

Error:

TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use the logical TensorFlow ops to test the value of a tensor.

asked Dec 05 '22 00:12 by J Wang


2 Answers

tf.select no longer works (it was replaced by tf.where), as this thread also points out: https://github.com/tensorflow/tensorflow/issues/8647

What worked for me was tf.where:

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
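A self-contained version of this snippet, as a sketch assuming TensorFlow 2.x (eager execution by default). It uses a small hand-picked input instead of random values so the result can be checked by hand; the function name `piecewise` is illustrative only. Note that both `tf.square(x)` and `x + 5` are computed for every element, and tf.where only chooses between them elementwise.

```python
import tensorflow as tf  # assumes TensorFlow 2.x, where eager execution is the default


def piecewise(x):
    """Elementwise piecewise rule: x**2 where x > 0, x + 5 elsewhere."""
    condition = tf.greater(x, 0)  # boolean tensor with the same shape as x
    # Both branches are computed for every element; tf.where picks one per element.
    return tf.where(condition, tf.square(x), x + 5)


x = tf.constant([[-2.0], [0.0], [3.0]])  # shape [3, 1], like the question's [100, 1]
print(piecewise(x).numpy().ravel())  # [3. 5. 9.]
```

Because 0 is not greater than 0, the middle element takes the `x + 5` branch.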
answered Dec 06 '22 13:12 by Priyank Pathak


You should take a look at tf.where.

For your example, you could do:

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)

EDIT: moved from tf.select to tf.where.
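Since the question builds its graph around a tf.Session, here is a sketch of the same tf.where approach in that TF1 style. It assumes TensorFlow 2.x and reaches the TF1 API through `tensorflow.compat.v1`. Note that the variable has to be initialized and the result fetched with `sess.run`, which the question's code never does.

```python
# Sketch assuming TensorFlow 2.x, using the TF1 API via tf.compat.v1
# to mirror the Session-based code in the question.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build a graph and run it in a Session, TF1 style

x = tf.Variable(tf.random_normal([100, 1]))
condition = tf.greater(x, 0)                    # elementwise x > 0
res = tf.where(condition, tf.square(x), x + 5)  # x**2 where True, x + 5 elsewhere

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # variables must be initialized first
    x_val, res_val = sess.run([x, res])          # fetch input and result in one run

# Compare against the same piecewise rule computed in NumPy
print(np.allclose(res_val, np.where(x_val > 0, x_val ** 2, x_val + 5)))  # True
```

Fetching `x` and `res` in the same `sess.run` call guarantees that the NumPy check sees exactly the values tf.where operated on.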

answered Dec 06 '22 15:12 by Olivier Moindrot