What is the difference between tf.cond and if-else?
```python
import tensorflow as tf

x = 'x'  # a plain Python string
y = tf.cond(tf.equal(x, 'x'), lambda: 1, lambda: 0)
with tf.Session() as sess:
    print(sess.run(y))

x = 'y'  # rebind the Python name, then evaluate y again
with tf.Session() as sess:
    print(sess.run(y))
```
```python
import tensorflow as tf

x = tf.Variable('x')
y = tf.cond(tf.equal(x, 'x'), lambda: 1, lambda: 0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(y))

tf.assign(x, 'y')  # try to change the variable, then evaluate y again
with tf.Session() as sess:
    init.run()
    print(sess.run(y))
```
The output is 1 in both cases.

Does this mean that only `tf.placeholder` works, and not every tensor, such as `tf.Variable`? When should I choose an if-else condition and when should I use `tf.cond`? What are the differences between them?
`tf.cond` stitches together the graph fragments created during the `true_fn` and `false_fn` calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of `pred`. `tf.cond` supports nested structures as implemented in `tensorflow.python`.
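A minimal sketch of that behavior, assuming TensorFlow 2.x with the `tf.compat.v1` module (on TF 1.x the same code runs with `import tensorflow as tf` and no `disable_v2_behavior()` call). Both branch functions are called once while the graph is built; after that, the predicate fed through a placeholder selects a branch on each `sess.run`. The branches return a `(number, label)` tuple to show a nested structure. The names `true_fn`, `false_fn`, and `calls` are illustrative only.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph + Session, as in the question

calls = []  # records Python-side calls to the branch functions

def true_fn():
    calls.append('true_fn')
    # Both branches return the same nested structure: a (number, label) tuple.
    return tf.constant(1), tf.constant('yes')

def false_fn():
    calls.append('false_fn')
    return tf.constant(0), tf.constant('no')

pred = tf.placeholder(tf.bool, shape=())  # scalar predicate, fed at run time
num, label = tf.cond(pred, true_fn, false_fn)
# Both functions have already been called here, at graph construction time.
calls_after_build = list(calls)

with tf.Session() as sess:
    on_true = sess.run((num, label), feed_dict={pred: True})
    on_false = sess.run((num, label), feed_dict={pred: False})

# Running the graph does not call true_fn/false_fn again.
print(calls_after_build, calls)
print(on_true, on_false)
```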
The `tf.where()` function returns elements from either the first tensor or the second tensor, element by element, depending on the specified condition: where the condition is true, it selects from the first tensor; otherwise it selects from the second tensor. Syntax: `tf.where(condition, a, b)`.
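A small sketch of the element-wise selection, again assuming the `tf.compat.v1` graph API. Unlike `tf.cond`, whose scalar predicate picks one whole branch, `tf.where` takes a boolean tensor and picks each element independently. The tensors `condition`, `a`, and `b` are made-up example data; they have the same shape here because the TF1 version of `tf.where` requires that.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

condition = tf.constant([True, False, True, False])
a = tf.constant([1, 2, 3, 4])
b = tf.constant([10, 20, 30, 40])

# Element i comes from `a` where condition[i] is True, otherwise from `b`.
picked = tf.where(condition, a, b)

with tf.Session() as sess:
    result = sess.run(picked)

print(result.tolist())  # [1, 20, 3, 40]
```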
`tf.cond` is evaluated at run time, whereas a Python `if-else` is evaluated at graph construction time. If you want the condition to depend on the value of a tensor at run time, `tf.cond` is the best option.
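Applied to the question's code: in the first snippet `x` is a plain Python string, so `tf.equal(x, 'x')` is fixed when the graph is built and rebinding `x` afterwards does not touch the graph. In the second snippet `tf.assign(x, 'y')` only creates an assignment op that is never run, and each new session re-runs the initializer, which resets `x` to `'x'`. A minimal sketch of both fixes, assuming the `tf.compat.v1` API: run the assignment in the same session, or feed the value through a placeholder. The names `assign_y`, `p`, and `z` are illustrative.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Variable version: the condition reads the variable's current value at run time.
x = tf.Variable('x')
y = tf.cond(tf.equal(x, 'x'), lambda: tf.constant(1), lambda: tf.constant(0))
assign_y = tf.assign(x, 'y')  # only builds an op; nothing changes until it runs

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(y)  # x holds 'x'
    sess.run(assign_y)    # actually execute the assignment
    after = sess.run(y)   # same session, so x now holds 'y'

# Placeholder version: the value is supplied on every run.
p = tf.placeholder(tf.string, shape=())
z = tf.cond(tf.equal(p, 'x'), lambda: tf.constant(1), lambda: tf.constant(0))
with tf.Session() as sess:
    z_x = sess.run(z, feed_dict={p: 'x'})
    z_y = sess.run(z, feed_dict={p: 'y'})

print(before, after, z_x, z_y)
```

So a `tf.Variable` works as well as a placeholder; what matters is that the value changes inside the session that evaluates the graph.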
Did you mean `if ... else` in Python vs. `tf.cond`?
You can use a Python `if ... else` to create different graphs for different external conditions. For example, you can write one Python script that builds graphs with 1, 2, or 3 hidden layers and use a command-line parameter to select which one to build.
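A sketch of that idea, assuming the `tf.compat.v1` API. The flags `--hidden` and `--activation`, the function `build_graph`, and the layer sizes are all hypothetical. The Python `for` loop and `if/else` run once, while the graph is built, so the number of layers and the activation op are fixed in the resulting graph rather than chosen at run time.

```python
import argparse

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def build_graph(num_hidden, activation='relu', n_inputs=3, width=4):
    """Builds a fresh graph whose structure is decided by Python control flow."""
    graph = tf.Graph()
    with graph.as_default():
        # Plain Python if/else: picks which op type goes into the graph.
        if activation == 'relu':
            act = tf.nn.relu
        else:
            act = tf.nn.tanh
        x = tf.placeholder(tf.float32, shape=(None, n_inputs), name='x')
        h, in_dim = x, n_inputs
        # One weight matrix per hidden layer; the loop runs only at build time.
        for i in range(num_hidden):
            w = tf.Variable(tf.random_normal((in_dim, width)), name='hidden_w%d' % i)
            h = act(tf.matmul(h, w))
            in_dim = width
        w_out = tf.Variable(tf.random_normal((in_dim, 1)), name='out_w')
        y = tf.matmul(h, w_out, name='y')
        init = tf.global_variables_initializer()  # initializer for this graph
    return graph, x, y, init

def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--hidden', type=int, choices=[1, 2, 3], default=1)
    parser.add_argument('--activation', choices=['relu', 'tanh'], default='relu')
    args = parser.parse_args(argv)

    graph, x, y, init = build_graph(args.hidden, args.activation)
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        out = sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]})
    n_weights = len(graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    return n_weights, out.shape

if __name__ == '__main__':
    print(main())
```

For example, `python build.py --hidden 3 --activation tanh` would build a graph with three hidden layers; nothing in the graph itself can change that choice later.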
`tf.cond` is for adding a conditional block to the graph. For example, you can define the Huber function with code like this:
```python
import tensorflow as tf

delta = tf.constant(1.)
x = tf.placeholder(tf.float32, shape=())

def left(x):
    # quadratic part, used when |x| <= delta
    return tf.multiply(x, x) / 2.

def right(x):
    # linear part, used when |x| > delta
    return tf.multiply(delta, tf.abs(x) - delta / 2.)

huber = tf.cond(tf.abs(x) <= delta, lambda: left(x), lambda: right(x))
```

The computation in the graph then follows a different branch for different input data:

```python
sess = tf.Session()
with sess.as_default():
    sess.run(tf.global_variables_initializer())
    print(sess.run(huber, feed_dict={x: 0.5}))
    print(sess.run(huber, feed_dict={x: 1.0}))
    print(sess.run(huber, feed_dict={x: 2.0}))
```

```
0.125
0.5
1.5
```
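The Huber example above takes a scalar `x` because the predicate of `tf.cond` must be a scalar. A possible batched variant, sketched under the assumption of the `tf.compat.v1` API, swaps in `tf.where` to select per element; note that both expressions are computed for every element and `tf.where` only chooses between the results. The names `quadratic`, `linear`, and `values` are illustrative.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

delta = tf.constant(1.)
x = tf.placeholder(tf.float32, shape=(None,))  # a batch of values

quadratic = tf.multiply(x, x) / 2.                    # used where |x| <= delta
linear = tf.multiply(delta, tf.abs(x) - delta / 2.)   # used where |x| > delta
# Element-wise selection instead of a single scalar branch decision.
huber_batch = tf.where(tf.abs(x) <= delta, quadratic, linear)

with tf.Session() as sess:
    values = sess.run(huber_batch, feed_dict={x: [0.5, 1.0, 2.0]})

print(values.tolist())  # [0.125, 0.5, 1.5]
```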