
What's the difference between tf.cond and if-else?

Tags:

tensorflow

What is the difference between tf.cond and a Python if-else?

Scenario 1

import tensorflow as tf

x = 'x'
y = tf.cond(tf.equal(x, 'x'), lambda: 1, lambda: 0)
with tf.Session() as sess:
    print(sess.run(y))
x = 'y'
with tf.Session() as sess:
    print(sess.run(y))

Scenario 2

import tensorflow as tf

x = tf.Variable('x')
y = tf.cond(tf.equal(x, 'x'), lambda: 1, lambda: 0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(y))

tf.assign(x, 'y')
with tf.Session() as sess:
    init.run()
    print(sess.run(y))

In both scenarios, the output is 1 each time.

Does this mean that tf.cond only works with tf.placeholder and not with other tensors, such as tf.Variable? When should I use an if-else condition, and when should I use tf.cond? What are the differences between them?

asked Aug 05 '17 03:08 by gaussclb



2 Answers

tf.cond is evaluated at runtime, each time the graph is run in a session, whereas a Python if-else is evaluated once, at graph construction time.
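A minimal sketch of that distinction, assuming a TensorFlow build that provides tf.compat.v1. The alias tf1 and the names flag, y_cond and y_if are made up for this illustration:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (illustrative alias)

graph = tf1.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.string, shape=())

    # tf.cond adds a conditional node to the graph; its predicate is
    # re-evaluated on every sess.run call.
    y_cond = tf1.cond(tf.equal(x, 'x'),
                      lambda: tf.constant(1),
                      lambda: tf.constant(0))

    # A Python if-else runs exactly once, right now, while the graph is
    # being built. Only the chosen constant ends up in the graph.
    flag = 'x'
    y_if = tf.constant(1) if flag == 'x' else tf.constant(0)
    flag = 'y'  # changing the Python value later does not touch the graph

with tf1.Session(graph=graph) as sess:
    # Each entry is [value of y_cond, value of y_if] for one fed input.
    results = [sess.run([y_cond, y_if], feed_dict={x: value})
               for value in ('x', 'y')]

print(results)
```

Here y_cond follows the value fed at run time, while y_if stays at whatever the Python condition selected when the graph was built.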

If the condition depends on the value of a tensor at runtime, tf.cond is the right choice.
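This is not limited to placeholders. The following sketch, again assuming tf.compat.v1 is available, suggests that a tf.Variable works as well, provided the assign op is actually run in the session and the variable is not re-initialized afterwards. In Scenario 2 of the question, tf.assign only creates an op that is never run, and init.run() resets x in the second session. The names set_x, before and after are illustrative:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (illustrative alias)

graph = tf1.Graph()
with graph.as_default():
    x = tf1.Variable('x')
    y = tf1.cond(tf.equal(x, 'x'),
                 lambda: tf.constant(1),
                 lambda: tf.constant(0))
    # tf.assign only builds an op; nothing changes until a session runs it.
    set_x = tf1.assign(x, 'y')
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    before = int(sess.run(y))  # x still holds its initial value 'x'
    sess.run(set_x)            # now the variable really holds 'y'
    after = int(sess.run(y))   # same session, so the new value is visible

print(before, after)
```

Variable values live in the session, so the assignment and the later evaluation of y have to happen in the same session for the change to be observed.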

answered Oct 18 '22 04:10 by Ishant Mrinal


Did you mean if ... else in Python vs. tf.cond?

You can use if ... else to build different graphs for different external conditions. For example, you can write one Python script that builds graphs with 1, 2, or 3 hidden layers and use command-line parameters to select which one to build.
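As a rough sketch of that use case, assuming tf.compat.v1 is available: the function build_graph, its num_hidden parameter, and the fixed all-ones weights are hypothetical stand-ins, and in a real script num_hidden would come from a command-line parser.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API (illustrative alias)


def hidden_layer(net, width):
    # Toy layer with fixed all-ones weights, just to add nodes to the graph.
    return tf.nn.relu(tf.matmul(net, tf.ones((width, width))))


def build_graph(num_hidden, width=4):
    """Build a graph with 1, 2, or 3 hidden layers (hypothetical helper)."""
    graph = tf1.Graph()
    with graph.as_default():
        net = tf1.placeholder(tf.float32, shape=(None, width), name='inputs')
        # Ordinary Python conditions: they run once, while the graph is
        # being built, and decide which nodes exist at all.
        net = hidden_layer(net, width)
        if num_hidden >= 2:
            net = hidden_layer(net, width)
        if num_hidden >= 3:
            net = hidden_layer(net, width)
        tf.identity(net, name='outputs')
    return graph


def count_hidden(graph):
    # Each hidden layer contributes exactly one Relu op.
    return sum(op.type == 'Relu' for op in graph.get_operations())


layer_counts = {n: count_hidden(build_graph(n)) for n in (1, 2, 3)}
print(layer_counts)
```

Each call produces a structurally different graph. Once a graph is built, the Python conditions are gone and cannot react to the data fed in later.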

tf.cond adds a conditional block to the graph itself. For example, you can define the Huber function like this:

import tensorflow as tf
delta = tf.constant(1.)
x = tf.placeholder(tf.float32, shape=())

def left(x):
    return tf.multiply(x, x) / 2.
def right(x):
    return tf.multiply(delta, tf.abs(x) - delta / 2.)

huber = tf.cond(tf.abs(x) <= delta, lambda: left(x), lambda: right(x))

The computation in the graph then follows a different branch depending on the input data:

sess = tf.Session()
with sess.as_default():
    sess.run(tf.global_variables_initializer())
    print(sess.run(huber, feed_dict={x: 0.5}))
    print(sess.run(huber, feed_dict={x: 1.0}))
    print(sess.run(huber, feed_dict={x: 2.0}))

> 0.125
> 0.5
> 1.5
answered Oct 18 '22 03:10 by Vladimir Bystricky