
In TensorFlow, how do I use tf.gather() on the last dimension?

I am trying to gather slices of a tensor along its last dimension, to implement a partial connection between layers. The output tensor has shape [batch_size, h, w, depth], and I want to select slices along the last dimension, for example:

# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]

However, tf.gather(L, [0, 2, 3, 8]) seems to work only on the first dimension (is that right?). How can I do this for the last dimension?

YW P Kwon asked Apr 21 '16 09:04




3 Answers

As of TensorFlow 1.3, tf.gather has an axis parameter, so the workarounds described in the other answers are no longer necessary.

https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
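
For the tensor in the question, this could look roughly like the sketch below. It assumes the TensorFlow 2.x eager API; the shape and the random input are placeholders chosen only for illustration.

```python
import tensorflow as tf

# Placeholder for the intermediate tensor L: [batch_size, h, w, depth].
L = tf.random.uniform([2, 4, 4, 10])

# Select channels 0, 2, 3 and 8 along the last dimension.
partL = tf.gather(L, [0, 2, 3, 8], axis=-1)

print(partL.shape)  # (2, 4, 4, 4)
```

Only the last dimension changes: it takes the length of the index list, while batch_size, h and w are preserved.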

rryan answered Oct 08 '22 09:10


There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206

For now you can:

  1. transpose your tensor so that the dimension to gather is first (transpose is expensive)

  2. reshape your tensor to 1-D (reshape is cheap), convert your column indices into a list of linear element indices, gather, then reshape back

  3. use gather_nd. You will still need to turn your column indices into a list of individual element indices.
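
The three workarounds above can be sketched as follows. This is only an illustration, not code from the answer: it assumes the TensorFlow 2.x eager API, and the tensor `L`, its shape, and the column list `cols` are made-up stand-ins for the [batch_size, h, w, depth] tensor in the question.

```python
import tensorflow as tf

# Small stand-in for a [batch_size, h, w, depth] tensor (illustrative values).
L = tf.reshape(tf.range(2 * 3 * 3 * 10), [2, 3, 3, 10])
cols = tf.constant([0, 2, 3, 8])
depth = tf.shape(L)[-1]

# Workaround 1: move the last axis to the front, gather, move it back.
front = tf.transpose(L, [3, 0, 1, 2])
part_t = tf.transpose(tf.gather(front, cols), [1, 2, 3, 0])

# Workaround 2: flatten, convert column indices to linear indices,
# gather, then reshape back.
num_rows = tf.size(L) // depth                   # batch_size * h * w
row_starts = tf.range(num_rows) * depth          # offset of each depth-vector
flat_idx = row_starts[:, None] + cols[None, :]   # [num_rows, len(cols)]
part_f = tf.gather(tf.reshape(L, [-1]), flat_idx)
out_shape = tf.concat([tf.shape(L)[:-1], tf.shape(cols)], axis=0)
part_f = tf.reshape(part_f, out_shape)

# Workaround 3: gather_nd on a 2-D view, one (row, col) pair per output element.
L2d = tf.reshape(L, [-1, depth])
rr, cc = tf.meshgrid(tf.range(num_rows), cols, indexing="ij")
part_nd = tf.gather_nd(L2d, tf.stack([rr, cc], axis=-1))
part_nd = tf.reshape(part_nd, out_shape)
```

All three results have shape [2, 3, 3, 4] and hold the same values; they differ only in how much index bookkeeping and data movement they need.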

Yaroslav Bulatov answered Oct 08 '22 09:10


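With gather_nd you can now pick one element per row of a matrix x as follows, where indices_for_dim1 is an int32 vector holding one column index per row:

# Pair each row number with its column index: shape [num_rows, 2].
cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(x, cat_idx)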

Also, as reported by user Nova in the thread referenced in @Yaroslav Bulatov's answer:

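import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])  # one column index per row
# Linear index of element (row, col) in the flattened tensor: row * num_cols + col
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices

with tf.Session(''):
  print(y.eval())  # [2 4 9]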

The gist is to flatten the tensor and use strided 1-D addressing with tf.gather(...).

Andrei Pokrovsky answered Oct 08 '22 09:10