 

How to add if condition in a TensorFlow graph?

Let's say I have the following code:

```python
x = tf.placeholder("float32", shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder("int32", shape=[1, 1], name="condition")
W = tf.Variable(tf.zeros([ins_size**2*3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)
```

Would the if statement work in the calculation (I do not think so)? If not, how can I add an if statement into the TensorFlow calculation graph?

Asked Mar 06 '16 at 21:03 by Yee Liu





2 Answers

You're correct that the if statement doesn't work here, because the condition is evaluated at graph construction time, whereas presumably you want the condition to depend on the value fed to the placeholder at runtime. (In fact, it will always take the first branch, because condition > 0 evaluates to a Tensor, which is "truthy" in Python.)
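On current TensorFlow 2.x releases (an assumption about the installed version), a symbolic Tensor is no longer silently "truthy": using one as a Python bool raises a TypeError subclass instead. The sketch below builds a TF1-style graph through tf.compat.v1 and shows that the Python if cannot branch on the placeholder, because no value exists yet when the if runs. The variable names chosen and message are illustrative only.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():  # TF1-style graph construction, no eager execution here
    condition = tf.compat.v1.placeholder(tf.int32, shape=[], name="condition")
    try:
        # Python evaluates this `if` once, while the graph is being built,
        # before any value has been fed to `condition`.
        if condition > 0:
            chosen = "first branch"
        else:
            chosen = "second branch"
    except TypeError as err:  # symbolic Tensors refuse conversion to a Python bool
        chosen = None
        message = str(err)  # keep the text; `err` is unbound after the except block
```

Either way (always-first-branch on old releases, an error on newer ones), the branch is not chosen from the value you feed at run time.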

To support conditional control flow, TensorFlow provides the tf.cond() operator, which evaluates one of two branches, depending on a boolean condition. To show you how to use it, I'll rewrite your program so that condition is a scalar tf.int32 value for simplicity:

```python
import tensorflow as tf  # TensorFlow 1.x API

x = tf.placeholder(tf.float32, shape=[None, ins_size**2 * 3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

# tf.cond takes two callables; only the ops created inside the chosen one run.
y = tf.cond(condition > 0,
            lambda: tf.matmul(x, W) + b,
            lambda: tf.matmul(x, W) - b)
```
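To check that tf.cond really decides at run time, here is a minimal sketch that builds the graph once and runs it with two different values fed to condition. It assumes a TensorFlow 2.x install, so it reaches the TF1 API through tf.compat.v1 inside an explicit tf.Graph. The sizes n_in/n_out, the all-ones weights and the 0.5 bias are hypothetical stand-ins for the question's ins_size, label_option and tf.Variables (constants avoid the variable-initialization step).

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes standing in for ins_size**2*3 and label_option.
n_in, n_out = 3, 2

graph = tf.Graph()
with graph.as_default():  # build a TF1-style graph under TF 2.x
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, n_in], name="x_input")
    condition = tf.compat.v1.placeholder(tf.int32, shape=[], name="condition")
    W = tf.ones([n_in, n_out])   # constants instead of tf.Variable: no initializer needed
    b = tf.fill([n_out], 0.5)
    y = tf.cond(condition > 0,
                lambda: tf.matmul(x, W) + b,
                lambda: tf.matmul(x, W) - b)

with tf.compat.v1.Session(graph=graph) as sess:
    batch = np.ones([1, n_in], dtype=np.float32)
    # The same graph takes a different branch depending on the fed value.
    y_pos = sess.run(y, feed_dict={x: batch, condition: 1})  # 3 + 0.5
    y_neg = sess.run(y, feed_dict={x: batch, condition: 0})  # 3 - 0.5
```

Each output entry is the sum of three ones (3.0), then plus or minus the 0.5 bias, so the two runs differ only because of the fed condition.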
Answered Sep 26 '22 at 22:09 by mrry



TensorFlow 2.0

TF 2.0 introduces a feature called AutoGraph, which converts Python code into TensorFlow graph code when a tf.function is traced. This means you can use Python control-flow statements (yes, this includes if statements) on Tensor values. From the docs:

AutoGraph supports common Python statements like while, for, if, break, continue and return, with support for nesting. That means you can use Tensor expressions in the condition of while and if statements, or iterate over a Tensor in a for loop.
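As an illustration of the nesting described in that quote, here is a small sketch with a while loop whose condition is a Tensor and an if nested inside it. The function collatz_steps is hypothetical, chosen only because it needs both a data-dependent loop and a data-dependent branch; it assumes TensorFlow 2.x with AutoGraph enabled (the default for tf.function).

```python
import tensorflow as tf

@tf.function
def collatz_steps(n):
    """Count Collatz steps from n down to 1 (hypothetical example function)."""
    steps = tf.constant(0)
    # `n > 1` is a Tensor, so AutoGraph converts this loop into a tf.while_loop...
    while n > 1:
        # ...and this nested `if` into a tf.cond.
        if n % 2 == 0:
            n = n // 2
        else:
            n = 3 * n + 1
        steps += 1
    return steps

collatz_steps(tf.constant(6))  # 6 -> 3 -> 10 -> 5 -> 16 -> 8 -> 4 -> 2 -> 1, i.e. 8 steps
```

Both branches assign n with the same dtype and shape, which AutoGraph requires for variables that are modified inside converted control flow.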

You will need to define a function implementing your logic and decorate it with @tf.function. Here is a modified example from the documentation:

```python
import tensorflow as tf

@tf.function
def sum_even(items):
    s = 0
    for c in items:
        if tf.equal(c % 2, 0):
            s += c
    return s

sum_even(tf.constant([10, 12, 15, 20]))
# <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>
```
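Applied to the code in the question, a minimal TF 2.x sketch could keep the original if/else almost verbatim inside a tf.function. The weights below are hypothetical constants (3 input features, 2 labels) rather than the question's tf.Variables, and the bias [1, 0] is chosen only so the two branches give visibly different softmax outputs.

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's weights and bias.
W = tf.ones([3, 2])
b = tf.constant([1.0, 0.0])

@tf.function
def predict(x, condition):
    # `condition > 0` is a Tensor, so AutoGraph converts this if/else into a tf.cond.
    if condition > 0:
        y = tf.nn.softmax(tf.matmul(x, W) + b)
    else:
        y = tf.nn.softmax(tf.matmul(x, W) - b)
    return y

x = tf.ones([1, 3])
y_pos = predict(x, tf.constant(1)).numpy()  # logits [4, 3]: first label favored
y_neg = predict(x, tf.constant(0)).numpy()  # logits [2, 3]: second label favored
```

Because condition is passed as a Tensor, both calls reuse the same traced graph and the branch is picked at run time; passing a plain Python int instead would make AutoGraph evaluate the if during tracing and retrace for each new value.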
Answered Sep 22 '22 at 22:09 by cs95
