
Flatten batch in tensorflow

Tags:

tensorflow

I have an input to TensorFlow of shape [None, 9, 2], where None is the batch dimension.

To perform further operations on it (e.g. matmul), I need to reshape it to [None, 18]. How do I do that?

Cactux asked Apr 16 '16 19:04


People also ask

What is flatten in tensorflow?

TensorFlow's flatten operation (for example tf.keras.layers.Flatten) collapses every dimension of the input except the batch dimension into a single dimension, so an input of shape (batch, 9, 2) becomes (batch, 18). It does not affect the batch size.
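A minimal sketch of that behaviour, assuming TensorFlow 2.x with eager execution; the batch size of 4 and the variable names are illustrative only:

```python
import tensorflow as tf

# A hypothetical batch of 4 examples, each of shape (9, 2) as in the question.
x = tf.zeros([4, 9, 2])

# Flatten keeps axis 0 (the batch) and merges the rest: (4, 9, 2) -> (4, 18).
flatten = tf.keras.layers.Flatten()
flat = flatten(x)
print(flat.shape)
```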

How do you flatten a TF tensor?

To flatten the tensor, use the TensorFlow reshape operation: call tf.reshape with the tensor (here tf_initial_tensor_constant) and the shape [-1], i.e. a -1 inside a Python list. The -1 tells TensorFlow to infer that dimension, so all elements end up in a single dimension.
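A small sketch of this, assuming TensorFlow 2.x eager execution; `t` is a hypothetical stand-in for tf_initial_tensor_constant. It also shows why this form is not what the question needs: reshaping to [-1] merges the batch axis as well.

```python
import tensorflow as tf

# Hypothetical stand-in for tf_initial_tensor_constant.
t = tf.constant([[1, 2, 3], [4, 5, 6]])

# The shape [-1] asks tf.reshape to infer the single dimension: 2 * 3 = 6.
flat = tf.reshape(t, [-1])
print(flat.numpy().tolist())  # [1, 2, 3, 4, 5, 6]

# Caution: applied to a batch of shape (4, 9, 2), [-1] also merges the batch
# axis, giving shape (72,) rather than (4, 18).
print(tf.reshape(tf.zeros([4, 9, 2]), [-1]).shape)
```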

What does a flatten layer do?

A Flatten layer turns each example's multidimensional input into a one-dimensional vector. It is commonly used in the transition from a convolutional layer to a fully connected layer.
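A sketch of that conv-to-dense transition, assuming tf.keras in TensorFlow 2.x; the 28x28 grayscale input, the 8 filters and the 10 output units are hypothetical choices for illustration:

```python
import tensorflow as tf

# Hypothetical toy model: one conv layer, then Flatten, then one dense layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),  # -> (None, 26, 26, 8)
    tf.keras.layers.Flatten(),                        # -> (None, 26 * 26 * 8) = (None, 5408)
    tf.keras.layers.Dense(10),                        # fully connected -> (None, 10)
])

out = model(tf.zeros([2, 28, 28, 1]))
print(out.shape)
```

The Dense layer needs a flat feature vector per example, which is exactly what Flatten supplies.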

What does flatten layer do in Keras?

Keras's Flatten layer flattens the input without affecting the batch size. If the inputs are shaped (batch,) without a feature axis, flattening adds an extra channel dimension, so the output shape is (batch, 1). The optional data_format argument is a string, either channels_last (the default) or channels_first, giving the ordering of the dimensions in the input.
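Both notes can be checked with a short sketch, assuming tf.keras in TensorFlow 2.x; the tensors are arbitrary examples. With channels_first, Keras moves the channel axis to the end before flattening, so the element order of the output changes:

```python
import tensorflow as tf

# Inputs shaped (batch,) with no feature axis gain one: (3,) -> (3, 1).
scalars = tf.constant([1.0, 2.0, 3.0])
scalar_out = tf.keras.layers.Flatten()(scalars)
print(scalar_out.shape)

# A (batch, C, H, W) tensor with C=2, H=1, W=2 holding the values 0..3.
x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 2, 1, 2])
last = tf.keras.layers.Flatten()(x)                                # default channels_last
first = tf.keras.layers.Flatten(data_format="channels_first")(x)  # channel axis moved last first
print(last.numpy().tolist())   # [[0.0, 1.0, 2.0, 3.0]]
print(first.numpy().tolist())  # [[0.0, 2.0, 1.0, 3.0]]
```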


2 Answers

You can do it easily with tf.reshape() without knowing the batch size.

    import numpy
    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

    x = tf.placeholder(tf.float32, shape=[None, 9, 2])
    shape = x.get_shape().as_list()   # a list: [None, 9, 2]
    dim = numpy.prod(shape[1:])       # dim = 9 * 2 = 18
    x2 = tf.reshape(x, [-1, dim])     # -1: infer this dimension (the batch size)

The -1 in the last line tells tf.reshape() to infer that dimension from the total number of elements, so the reshape works whatever the batch size turns out to be at run time. See the tf.reshape() documentation for details.
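The same idea as a TensorFlow 2.x eager sketch (an assumption: the answer itself targets TF 1.x placeholders). `flatten_batch` is a hypothetical helper name; it relies on the non-batch dimensions being statically known, as in the [None, 9, 2] case:

```python
import numpy as np
import tensorflow as tf

def flatten_batch(x):
    # Hypothetical helper: assumes every non-batch dimension is statically known.
    dim = int(np.prod(x.shape[1:].as_list()))  # [9, 2] -> 18
    # -1 lets tf.reshape infer the batch size: total elements // dim.
    return tf.reshape(x, [-1, dim])

for batch_size in (1, 5, 32):
    print(flatten_batch(tf.zeros([batch_size, 9, 2])).shape)
```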


Update: shape = [None, 3, None]

Thanks @kbrose. For cases where more than one dimension is undefined, we can use tf.shape() together with tf.reduce_prod() instead.

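    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

    x = tf.placeholder(tf.float32, shape=[None, 3, None])
    dim = tf.reduce_prod(tf.shape(x)[1:])  # product of the non-batch dims, computed at run time
    x2 = tf.reshape(x, [-1, dim])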

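tf.shape() returns the shape as a Tensor that is evaluated at run time, whereas x.get_shape() returns the static shape known when the graph is built (with None for unknown dimensions). The documentation covers the difference between x.get_shape() and tf.shape() in more detail.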
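A TensorFlow 2.x sketch of the static vs. dynamic shape distinction, assuming tf.function with an input_signature as a stand-in for the TF 1.x placeholder; `flatten_dynamic` is a hypothetical name. Inside the traced function the static shape still contains None, so the numpy.prod approach would fail, while tf.shape() supplies the real sizes on each call:

```python
import tensorflow as tf

# Both the batch axis and the last axis are unknown, as in the update above.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3, None], dtype=tf.float32)])
def flatten_dynamic(x):
    # Static shape, fixed when the function is traced: (None, 3, None).
    print("static shape:", x.shape)  # Python print: runs at trace time only
    # Dynamic shape: an int32 Tensor whose values are filled in at run time.
    dim = tf.reduce_prod(tf.shape(x)[1:])
    return tf.reshape(x, [-1, dim])

print(flatten_dynamic(tf.zeros([4, 3, 5])).shape)  # (4, 15)
print(flatten_dynamic(tf.zeros([2, 3, 7])).shape)  # (2, 21)
```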

I also tried tf.contrib.layers.flatten() in another . It is the simplest option for the first case, but it cannot handle the second.

weitang114 answered Oct 06 '22 19:10


    import tensorflow as tf
    flat_inputs = tf.layers.flatten(inputs)  # TF 1.x API; keeps the batch axis: [None, 9, 2] -> [None, 18]
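tf.layers.flatten comes from the TensorFlow 1.x API; a minimal TF 2.x sketch, assuming tf.keras.layers.Flatten as its counterpart and a hypothetical random batch of 4 inputs, checks that it produces the same result as the tf.reshape approach from the first answer:

```python
import numpy as np
import tensorflow as tf

# Hypothetical random batch of 4 inputs shaped (9, 2).
inputs = tf.random.normal([4, 9, 2])

# Layer-based flattening (TF 2.x counterpart of tf.layers.flatten).
flat_layer = tf.keras.layers.Flatten()(inputs)
# The reshape approach from the first answer.
flat_reshape = tf.reshape(inputs, [-1, 18])

# Same shape and same element order: channels_last flattening is row-major.
print(flat_layer.shape, np.array_equal(flat_layer.numpy(), flat_reshape.numpy()))
```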
user149100 answered Oct 06 '22 19:10