I have a tensor logits with dimensions [batch_size, num_rows, num_coordinates] (i.e. each element of the batch is a matrix). In my case the batch size is 2, and there are 4 rows and 4 coordinates.
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
I want to select the first and second row of the first batch and the second and fourth row of the second batch.
indices = tf.constant([[0, 1], [1, 3]])
So the desired output would be
result = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0]],
                      [[15.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
How do I do this using TensorFlow? I tried tf.gather(logits, indices), but it did not return what I expected. Thanks!
This is possible in TensorFlow, but slightly inconvenient, because tf.gather() currently only works with one-dimensional indices, and only selects slices from the 0th dimension of a tensor. However, it is still possible to solve your problem efficiently, by transforming the arguments so that they can be passed to tf.gather():
logits = ...  # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])

# Use tf.shape() so this works with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]

# Offset to add to each row in indices. We use tf.expand_dims() to make
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)

# Convert indices and logits into the form expected by tf.gather().
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat([[-1], tf.shape(logits)[2:]], 0))

selected_rows = tf.gather(flattened_logits, flattened_indices)
result = tf.reshape(selected_rows,
                    tf.concat([tf.stack([batch_size, indices_per_batch]),
                               tf.shape(logits)[2:]], 0))
Note that, since this uses tf.reshape() and not tf.transpose(), it doesn't need to modify the (potentially large) data in the logits tensor, so it should be fairly efficient.
mrry's answer is great, but I think the problem can be solved with far fewer lines of code using tf.gather_nd (probably this function was not yet available when mrry wrote his answer):
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])
result = tf.gather_nd(logits, indices)
with tf.Session() as sess:
    print(sess.run(result))
This will print
[[[ 10.  10.  20.  20.]
  [ 11.  10.  10.  30.]]

 [[ 15.  11.  11.  21.]
  [ 17.  11.  11.  21.]]]
tf.gather_nd should be available as of v0.10. Check out this github issue for more discussion on this.
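For intuition, the same [batch, row] index pairs can be sketched with NumPy fancy indexing (a hypothetical equivalent, not the TensorFlow API itself): splitting the last axis of the index tensor into a batch-index array and a row-index array selects the same rows that tf.gather_nd does.

```python
import numpy as np

logits = np.array([[[10., 10., 20., 20.],
                    [11., 10., 10., 30.],
                    [12., 10., 10., 20.],
                    [13., 10., 10., 20.]],
                   [[14., 11., 21., 31.],
                    [15., 11., 11., 21.],
                    [16., 11., 11., 21.],
                    [17., 11., 11., 21.]]])

# Each inner [b, r] pair picks row r of batch b.
indices = np.array([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])

# Fancy indexing with the batch and row components side by side
# mirrors what tf.gather_nd(logits, indices) computes.
result = logits[indices[..., 0], indices[..., 1]]
```

result has shape (2, 2, 4) and matches the output printed above.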