If a conditional depends on expensive operations, you may want lazy behavior, i.e. to evaluate only the branch that is chosen.
The following works, and is lazy:
>>> a = tf.zeros(0)
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.argmax(a)).eval()
-1
You can see that it's lazy because the argmax is never evaluated: evaluating it would cause an error, because the tensor it is taken on is empty. If you move the argmax out of the lambda, it yields this very error:
>>> am = tf.argmax(a)
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(am, 1)).eval()
... Reduction axis 0 is empty in shape [0]
This error is not caused by the tf.add operation. Move the argmax back inline and it works again:
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(tf.argmax(a), 1)).eval()
-1
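The difference between the two variants can be illustrated without TensorFlow. The following is only a plain-Python analogy, not TensorFlow code: `lazy_cond` and `argmax` are hypothetical helpers written for this sketch. It shows that work wrapped in a callable is deferred until the branch is chosen, while work done outside the callable happens regardless of the predicate.

```python
def lazy_cond(pred, true_fn, false_fn):
    """Call only the branch selected by pred (hypothetical helper, not a TensorFlow API)."""
    return true_fn() if pred else false_fn()


def argmax(values):
    """Return the index of the largest element; raise ValueError on an empty list."""
    if not values:
        raise ValueError("argmax of an empty sequence")
    return max(range(len(values)), key=values.__getitem__)


a = []

# Lazy: argmax sits inside the lambda, so it is never called for the empty list.
result = lazy_cond(len(a) == 0, lambda: -1, lambda: argmax(a))
print(result)  # prints -1

# Not lazy: argmax is computed outside the branch function, so it runs
# no matter which branch the predicate would have selected.
eager_error = None
try:
    am = argmax(a)
    lazy_cond(len(a) == 0, lambda: -1, lambda: am + 1)
except ValueError as exc:
    eager_error = str(exc)
    print("evaluation outside the branch failed:", exc)
```

The analogy is loose: in TensorFlow graph mode the error surfaces when the graph is run rather than on the line that defines `am`. The cause is the same, though: only operations created inside the branch functions are executed conditionally.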
The question, then, is: how do you write lazy conditionals in a cleaner way?
When the conditional functions get long, the above approach gets a bit messy. What you can do is define a lambda expression outside the conditional. Note that the following does not work in the Python interactive REPL, where it results in ValueError: Operation 'cond_14/Merge' has been marked as not fetchable.
It does work when you put the code into a Python file and run it in the normal way.
import tensorflow as tf

sess = tf.InteractiveSession()

a = tf.zeros(0)
fn = lambda: tf.argmax(a)

res = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)),
    lambda: tf.constant(-1, dtype=tf.int64),
    fn
).eval()
print(res)

res2 = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)),
    lambda: tf.constant(-1, dtype=tf.int64),
    lambda: tf.add(fn(), tf.constant(1, dtype=tf.int64))
).eval()
print(res2)

# Output:
# -1
# -1