
Reshape tensor using placeholder value

Tags: tensorflow

I want to reshape a tensor using the [int, -1] notation (to flatten an image, for example), but I don't know the first dimension ahead of time. One use case is training on a large batch and then evaluating on a smaller batch.

Why does the code below fail with the error "got list containing Tensors of type '_Message'"?

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, shape=[None, 28, 28])
batch_size = tf.placeholder(tf.int32)

def reshape(_batch_size):
    return tf.reshape(x, [_batch_size, -1])

reshaped = reshape(batch_size)


with tf.Session() as sess:
    # Train on a large batch
    sess.run([reshaped], feed_dict={x: np.random.rand(100, 28, 28), batch_size: 100})

    # Evaluate on a smaller batch
    sess.run([reshaped], feed_dict={x: np.random.rand(8, 28, 28), batch_size: 8})

Note: when I call tf.reshape() directly, outside of a function, it seems to work. However, I have very large models that I reuse multiple times, so I need to keep them in functions and pass the dimension in as an argument.

jstaker7 asked Feb 13 '16 01:02




1 Answer

To make this work, replace the function:

def reshape(_batch_size):
    return tf.reshape(x, [_batch_size, -1])

…with the function:

def reshape(_batch_size):
    return tf.reshape(x, tf.pack([_batch_size, -1]))

The reason for the error is that tf.reshape() expects a value that is convertible to a tf.Tensor as its second argument. TensorFlow will automatically convert a list of Python numbers to a tf.Tensor, but it will not automatically convert a mixed list of numbers and tensors (such as a tf.placeholder()); instead, it raises the somewhat unintuitive error message you saw.

The tf.pack() op takes a list of objects convertible to a tensor, and converts each element individually, so it can handle the combination of a placeholder and an integer.
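For reference, the fix can be sketched end to end as follows. This is only a minimal sketch under a few assumptions that are not in the original post: it targets TensorFlow 2.x through the tf.compat.v1 graph-mode API, it uses tf.stack() (the name that tf.pack() was given in TensorFlow 1.0), and it declares batch_size as a scalar placeholder. The variable names train_out and eval_out are illustrative.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x, where placeholders and sessions live under
# tf.compat.v1 and graph mode must be enabled explicitly.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

x = tf1.placeholder(tf.float32, shape=[None, 28, 28])
# Assumption: batch_size is declared as a scalar (shape=[]).
batch_size = tf1.placeholder(tf.int32, shape=[])


def reshape(_batch_size):
    # tf.stack (formerly tf.pack) converts each list element to a tensor
    # and joins them into a single 1-D shape tensor.
    return tf.reshape(x, tf.stack([_batch_size, -1]))


reshaped = reshape(batch_size)

with tf1.Session() as sess:
    # Large batch
    train_out = sess.run(
        reshaped, feed_dict={x: np.random.rand(100, 28, 28), batch_size: 100})
    # Smaller batch, same graph
    eval_out = sess.run(
        reshaped, feed_dict={x: np.random.rand(8, 28, 28), batch_size: 8})

print(train_out.shape)  # (100, 784)
print(eval_out.shape)   # (8, 784)
```

Both runs reuse the same graph; only the value fed for batch_size changes, so each 28 x 28 image is flattened to 784 values regardless of the batch size.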

mrry answered Oct 05 '22 11:10