 

Conditional execution in TensorFlow

Tags:

tensorflow

How can I choose to execute a portion of the graph based on a condition?

I have a part of my network which is to be executed only if a placeholder value is provided in feed_dict. An alternate path is taken if the value is not provided. How do I go about implementing this in TensorFlow?

Here are the relevant portions of my code:

sess.run(accuracy, feed_dict={inputs: mnist.test.images, outputs: mnist.test.labels})

N = tf.shape(outputs)
cost = 0
if N > 0:
    y_N = tf.slice(h_c, [0, 0], N)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y_N, outputs, name='xentropy')
    cost = tf.reduce_mean(cross_entropy, name='xentropy_mean')

In the above code, I'm looking for something to use in place of if N > 0:

Rinu Boney asked Nov 13 '15 06:11

People also ask

What is eager execution in TensorFlow?

With eager execution enabled, TensorFlow functions execute operations immediately (as opposed to adding them to a graph to be executed later in a tf.compat.v1.Session) and return concrete values (as opposed to symbolic references to a node in a computational graph).
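
For example, a minimal sketch (my own illustration, assuming TensorFlow 2.x where eager execution is on by default):

import tensorflow as tf  # TF 2.x, eager by default

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.matmul(a, a)   # the matmul runs immediately, no Session required
print(b.numpy())      # concrete values are available right away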

How do I enable eager execution in TensorFlow?

With TensorFlow 2.x, eager execution is enabled by default, allowing TensorFlow code to be run and evaluated line by line.
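
As a rough sketch: in TF 1.x it can be switched on explicitly (this has to happen at program startup, before any graph is built); in TF 2.x it is already on and no call is needed:

import tensorflow as tf

tf.compat.v1.enable_eager_execution()  # TF 1.x: enable eager mode at startup
print(tf.executing_eagerly())          # True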

What is tf.cond?

tf.cond stitches together the graph fragments created during the true_fn and false_fn calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred. tf.cond supports nested structures as implemented in tensorflow.python.
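
A minimal graph-mode sketch (TF 1.x-style API, with made-up tensors):

import tensorflow as tf  # TF 1.x-style graph mode

x = tf.placeholder(tf.float32, shape=[])
# Both branch functions are traced when the graph is built, but only the
# branch selected by the predicate actually runs in the session.
result = tf.cond(x > 0, lambda: x * 2.0, lambda: x - 1.0)

with tf.Session() as sess:
    print(sess.run(result, feed_dict={x: 3.0}))   # 6.0
    print(sess.run(result, feed_dict={x: -3.0}))  # -4.0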

What is AutoGraph in TensorFlow?

AutoGraph transforms a subset of Python which operates on TensorFlow objects into equivalent TensorFlow graph code. When executing the graph, it has the same effect as if you ran the original code in eager mode.
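
For instance, a small sketch (TF 2.x; the function name is made up):

import tensorflow as tf  # TF 2.x

@tf.function  # AutoGraph rewrites the data-dependent Python `if` into a graph conditional
def clip_negative(x):
    if x > 0:
        return x
    return tf.zeros_like(x)

print(clip_negative(tf.constant(3.0)))   # 3.0
print(clip_negative(tf.constant(-3.0)))  # 0.0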


1 Answer

Hrm. It's possible that what you want is tf.control_flow_ops.cond() https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L597

But that's not exported into the tf namespace, and I'm answering without checking how stable this interface is guaranteed to be; it is used in released models, though, so go for it. :)
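
Applied to the code in the question, it might look roughly like this (a sketch only, adapted to the question's tensors; in later TensorFlow releases the op is exported as tf.cond):

# Sketch: both branches of the cond must return tensors of the same type.
N = tf.shape(outputs)[0]   # number of label rows actually provided

def labelled_branch():
    y_N = tf.slice(h_c, [0, 0], tf.shape(outputs))
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        logits=y_N, labels=outputs, name='xentropy')
    return tf.reduce_mean(cross_entropy, name='xentropy_mean')

# Note: a plain tf.placeholder still has to be fed; something like
# tf.placeholder_with_default with zero rows lets the "no labels" case run.
cost = tf.cond(N > 0, labelled_branch, lambda: tf.constant(0.0))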

However: Because you actually know in advance what path you want when you construct the feed_dict, you could also take a different approach of invoking a separate path through your model. The standard way to do this is to, e.g., set up code like:

def model(input, n_greater_than):
    # ... cleverness ...
    if n_greater_than:
        # ... other cleverness ...
        pass
    return tf.reduce_mean(input)


out1 = model(input, True)
out2 = model(input, False)

And then pull the out1 or out2 nodes depending upon what you know when you're about to run your computation and set the feed_dict. Remember that if the model references the same variables (create them outside the model() func), you'll basically have two separate paths through the graph that share the same weights.
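
A minimal sketch of that pattern (hypothetical shapes and names), with the variables created once outside model() so both calls share them:

# Weights created once, outside model(), so both paths share them.
w = tf.Variable(tf.zeros([784, 10]), name='w')
b = tf.Variable(tf.zeros([10]), name='b')

def model(x, use_extra_path):
    logits = tf.matmul(x, w) + b
    if use_extra_path:
        logits = tf.nn.relu(logits)   # ops present only in this path
    return logits

out1 = model(inputs, True)    # fetch this one when the extra input is fed
out2 = model(inputs, False)   # fetch this one otherwise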

You can see an example of this in the convolutional mnist example: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/mnist/convolutional.py#L165

I'm a fan of doing it this way without introducing control flow dependencies if you can.

dga answered Jan 27 '23 10:01