
How can I add an optional input to a graph in TensorFlow?

Tags:

tensorflow

I basically want the option to feed input into the middle of the graph and compute the output from there. One idea I had is to use tf.placeholder_with_default that defaults to a zero tensor. Then I could mix in the optional inputs using addition; however, addition on a large shape seems like a lot of unnecessary computation. Are there better ways of accomplishing this?

input_enabled = tf.placeholder_with_default(tf.constant(1.), [])  # scalar shape, so it matches the scalar default and the scalar fed below

input_shape = [None, in_size]
input = tf.placeholder_with_default(tf.zeros(input_shape), input_shape)
# ...
bottleneck_shape = [None, bottleneck_size]
bottleneck = input_enabled * f(prev_layer) + tf.placeholder_with_default(tf.zeros(bottleneck_shape), bottleneck_shape)
# ...

# Using graph with input at first layer:
sess.run([output], feed_dict={input: x})

# Using graph with input at bottleneck layer:
sess.run([output], feed_dict={bottleneck: b, input_enabled: 0.})
Lenar Hoyt asked Jun 20 '16 13:06


1 Answer

I understand better thanks to your code.

Basically the schema is:

       input       <- you can feed here
         |        
     (encoder)
         |
     bottleneck    <- you can also feed here instead
         |
     (decoder)
         |
       output

You want to support two use cases:

  1. train: feed an image into input, compute the output
  2. test: feed a code into the bottleneck, compute the output
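Both use cases rest on one property of feeding: a value fed for a tensor replaces that tensor, so the ops that would have produced it are skipped. As a rough mental model only (a hypothetical toy evaluator in plain Python, not how TensorFlow is implemented), the idea looks like this:

```python
# Hypothetical toy dataflow graph: node name -> (function, names of its inputs).
GRAPH = {
    "input": (None, []),                                      # has no op; must be fed
    "bottleneck": (lambda x: [v * 2 for v in x], ["input"]),   # toy "encoder"
    "output": (lambda b: [v + 1 for v in b], ["bottleneck"]),  # toy "decoder"
}


def run(fetch, feed_dict, trace):
    """Evaluate `fetch`. A fed node returns its fed value and its inputs are never visited.

    `trace` records every node whose function was actually executed.
    """
    if fetch in feed_dict:
        return feed_dict[fetch]
    fn, deps = GRAPH[fetch]
    if fn is None:
        raise KeyError(f"{fetch!r} must be fed")
    trace.append(fetch)
    return fn(*(run(d, feed_dict, trace) for d in deps))


# Use case 1: feed the input; both the encoder and the decoder run.
trace1 = []
out1 = run("output", {"input": [1, 2]}, trace1)

# Use case 2: feed the bottleneck; the encoder is skipped and "input" needs no value.
trace2 = []
out2 = run("output", {"bottleneck": [10, 20]}, trace2)
```

Here out1 is [3, 5] with both nodes executed, while out2 is [11, 21] with only the decoder executed.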

You don't need to create a placeholder for the bottleneck, because sess.run() lets you feed values to non-placeholder tensors in the graph:

input_shape = [None, in_size]
input = tf.placeholder(tf.float32, input_shape)
# ...

bottleneck = f(prev_layer)  # of shape [None, bottleneck_size]
# ...

# Using graph with input at first layer:
sess.run([output], feed_dict={input: x})

# Using graph with input at bottleneck layer:
sess.run([output], feed_dict={bottleneck: b})
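The snippet above is TensorFlow 1.x style. Below is a self-contained sketch that should also run under TensorFlow 2.x through the tf.compat.v1 API, assuming a toy single-matmul encoder and decoder with all-ones weights (the sizes and weights are made up for illustration):

```python
import numpy as np
import tensorflow as tf

# Placeholders and Sessions need graph mode; in TF 2.x this must run before any ops are built.
tf.compat.v1.disable_eager_execution()

in_size, bottleneck_size, out_size = 4, 2, 3  # hypothetical sizes

inputs = tf.compat.v1.placeholder(tf.float32, [None, in_size])
# Hypothetical encoder/decoder: one matmul each with constant all-ones weights.
w_enc = tf.constant(np.ones((in_size, bottleneck_size), np.float32))
w_dec = tf.constant(np.ones((bottleneck_size, out_size), np.float32))
bottleneck = tf.matmul(inputs, w_enc)  # ordinary tensor, shape [None, bottleneck_size]
output = tf.matmul(bottleneck, w_dec)  # shape [None, out_size]

with tf.compat.v1.Session() as sess:
    x = np.ones((1, in_size), np.float32)
    b = np.full((1, bottleneck_size), 0.5, np.float32)
    # Use case 1: feed the placeholder; the encoder runs.
    from_input = sess.run(output, feed_dict={inputs: x})
    # Use case 2: feed the intermediate tensor directly; no value for `inputs` is needed.
    from_bottleneck = sess.run(output, feed_dict={bottleneck: b})
```

The second call feeds nothing for inputs: because bottleneck is fed, the encoder and the placeholder are pruned from that run. With these weights, from_input is [[8., 8., 8.]] and from_bottleneck is [[1., 1., 1.]].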

From the documentation of sess.run():

The optional feed_dict argument allows the caller to override the value of tensors in the graph. Each key in feed_dict can be one of the following types:

If the key is a Tensor, the value may be a Python scalar, string, list, or numpy ndarray that can be converted to the same dtype as that tensor. Additionally, if the key is a placeholder, the shape of the value will be checked for compatibility with the placeholder.
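A related option, not part of the answer above: since tf.placeholder_with_default accepts any tensor as its default, it can wrap the encoder output directly. That gives the bottleneck an explicit, named entry point without the zero tensor, the addition, or the input_enabled flag from the question. A sketch with hypothetical toy sizes and all-ones weights, again via tf.compat.v1:

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # graph mode for placeholders and Sessions

in_size, bottleneck_size, out_size = 4, 2, 3  # hypothetical sizes

inputs = tf.compat.v1.placeholder(tf.float32, [None, in_size])
w_enc = tf.constant(np.ones((in_size, bottleneck_size), np.float32))
w_dec = tf.constant(np.ones((bottleneck_size, out_size), np.float32))

encoded = tf.matmul(inputs, w_enc)
# Optional entry point: evaluates to `encoded` unless a value is fed for it.
bottleneck = tf.compat.v1.placeholder_with_default(
    encoded, [None, bottleneck_size], name="bottleneck")
output = tf.matmul(bottleneck, w_dec)

with tf.compat.v1.Session() as sess:
    # Nothing fed for `bottleneck`: the default (the encoder output) is used.
    default_path = sess.run(
        output, feed_dict={inputs: np.ones((1, in_size), np.float32)})
    # A value fed for `bottleneck`: the encoder and `inputs` are not needed.
    fed_path = sess.run(
        output, feed_dict={bottleneck: np.full((1, bottleneck_size), 2.0, np.float32)})
```

Here default_path is [[8., 8., 8.]] and fed_path is [[4., 4., 4.]].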

Olivier Moindrot answered Oct 25 '22 03:10