I am trying to gather slices of a tensor along its last dimension, to implement partial connections between layers. Because the output tensor's shape is [batch_size, h, w, depth], I want to select slices based on the last dimension, such as:
# L is an intermediate tensor
partL = L[:, :, :, [0, 2, 3, 8]]
However, tf.gather(L, [0, 2, 3, 8]) seems to only work for the first dimension (right?). Can anyone tell me how to do it?
tf.gather extends indexing to handle tensors of indices. More generally, the output has the same shape as the input, with the indexed axis replaced by the shape of the indices; each slice of the output is equal to the slice of `params` along `axis` at the corresponding index.
You can use tf.expand_dims() to add a new dimension of size 1. You can also use tf.reshape() for this, but expand_dims is the recommended choice, because it states the intent directly and does not require spelling out the full new shape.
To flatten a tensor, use the TensorFlow reshape operation: pass tf.reshape the tensor (here represented by tf_initial_tensor_constant) and the target shape, which is a -1 inside a Python list, that is, [-1].
In TensorFlow, all computations involve tensors. A tensor is an n-dimensional array, a generalization of vectors and matrices, that can represent all types of data. All values in a tensor share the same data type, and the tensor has a known (or partially known) shape. The shape gives the size of the array along each dimension.
As of TensorFlow 1.3, tf.gather has an axis parameter, so the various workarounds here are no longer necessary.
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
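For the tensor in the question, the axis argument might be used roughly as follows. This is only a sketch: the zero-filled tensor and its sizes (2, 4, 4, 10) are made-up stand-ins for the real intermediate tensor L, and it assumes TensorFlow 1.3 or later is installed.

```python
import tensorflow as tf

# Hypothetical stand-in for the intermediate tensor L,
# with shape [batch_size, h, w, depth] = [2, 4, 4, 10].
L = tf.zeros([2, 4, 4, 10])

# Gather channels 0, 2, 3 and 8 along the last (depth) dimension, axis 3.
partL = tf.gather(L, [0, 2, 3, 8], axis=3)

# Only the gathered axis changes size: depth goes from 10 to 4.
print(partL.shape)
```

The first three dimensions are left untouched, so this should match what L[:, :, :, [0, 2, 3, 8]] was meant to express.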
There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206
For now you can:
- transpose your matrix so that the dimension to gather is first (transpose is expensive)
- reshape your tensor into 1-D (reshape is cheap), turn your gather column indices into a list of individual element indices using linear indexing, then reshape back
- use gather_nd; you will still need to turn your column indices into a list of individual element indices
With gather_nd you can now do this as follows:
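# Pair each row index with its column index: shape [num_rows, 2]
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
# Picks matrix[i, indices_for_dim1[i]] for every row i
result = tf.gather_nd(matrix, cat_idx)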
Also, as reported by user Nova in the thread referenced by @Yaroslav Bulatov:
import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices
with tf.Session():
    print(y.eval())  # [2 4 9]
The gist is to flatten the tensor and use strided 1-D addressing with tf.gather(...).
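The same strided addressing extends from one element per row to several columns per row, which is what the question asks for. The index arithmetic can be sketched in plain Python, with nested lists standing in for tensors; the helper name column_gather_flat_indices is hypothetical and not part of TensorFlow.

```python
def column_gather_flat_indices(num_rows, num_cols, columns):
    """Row-major flat indices that select `columns` from every row."""
    return [r * num_cols + c for r in range(num_rows) for c in columns]


matrix = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]
num_rows, num_cols = len(matrix), len(matrix[0])
columns = [0, 2]  # columns to keep

# "Reshape" to 1-D by concatenating the rows.
flat = [value for row in matrix for value in row]

# Turn column indices into individual element indices, then gather.
flat_idx = column_gather_flat_indices(num_rows, num_cols, columns)
picked = [flat[i] for i in flat_idx]

# "Reshape back" to [num_rows, len(columns)].
width = len(columns)
result = [picked[i:i + width] for i in range(0, len(picked), width)]
print(result)  # [[1, 3], [4, 6], [7, 9]]
```

In TensorFlow the three steps would correspond to tf.reshape, tf.gather on the flattened tensor, and a final tf.reshape; for a 4-D tensor, the leading dimensions can first be collapsed into a single row dimension.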