How to select rows from a 3-D Tensor in TensorFlow?

Tags: tensorflow

I have a tensor logits with dimensions [batch_size, num_rows, num_coordinates] (i.e. each element of the batch is a matrix of logits). In my case the batch size is 2, and there are 4 rows and 4 coordinates.

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])

I want to select the first and second row of the first batch and the second and fourth row of the second batch.

indices = tf.constant([[0, 1], [1, 3]])

So the desired output would be

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0]],
                      [[15.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])

How do I do this using TensorFlow? I tried using tf.gather(logits, indices) but it did not return what I expected. Thanks!
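
For what it's worth, here is the selection I am after written as plain NumPy fancy indexing, just to make the intent concrete (logits_np and indices_np are only hypothetical NumPy copies of the constants above):

import numpy as np

logits_np = ...  # NumPy array with the same values as `logits`, shape (2, 4, 4)
indices_np = np.array([[0, 1], [1, 3]])

# Pick rows indices_np[b] out of batch element b.
desired = logits_np[np.arange(2)[:, None], indices_np]  # shape (2, 2, 4)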

asked Mar 18 '16 by Clash


2 Answers

This is possible in TensorFlow, but slightly inconvenient, because tf.gather() currently only works with one-dimensional indices and only selects slices from the 0th dimension of a tensor. However, it is still possible to solve your problem efficiently by transforming the arguments so that they can be passed to tf.gather():

logits = ... # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])

# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]]))

selected_rows = tf.gather(flattened_logits, flattened_indices)

result = tf.reshape(selected_rows,
                    tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
                                  tf.shape(logits)[2:]]))

Note that, since this uses tf.reshape() and not tf.transpose(), it doesn't need to modify the (potentially large) data in the logits tensor, so it should be fairly efficient.
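
A side note for readers on newer TensorFlow releases: the argument order of tf.concat() changed to (values, axis) in TensorFlow 1.0, and tf.pack() was renamed tf.stack(), so the two reshape lines above would look roughly like this (a sketch, not the original answer's code):

flattened_logits = tf.reshape(logits, tf.concat([[-1], tf.shape(logits)[2:]], axis=0))

result = tf.reshape(selected_rows,
                    tf.concat([tf.stack([batch_size, indices_per_batch]),
                               tf.shape(logits)[2:]], axis=0))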

answered Nov 09 '22 by mrry


mrry's answer is great, but I think the problem can be solved in far fewer lines of code with the function tf.gather_nd (this function was probably not yet available when mrry wrote his answer):

import tensorflow as tf

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])

indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])

result = tf.gather_nd(logits, indices)
with tf.Session() as sess:
    print(sess.run(result))

This will print

[[[ 10.  10.  20.  20.]
  [ 11.  10.  10.  30.]]

 [[ 15.  11.  11.  21.]
  [ 17.  11.  11.  21.]]]
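
If you have the row indices as the 2-D tensor from the question rather than as hand-written (batch, row) pairs, the index pairs can also be built on the fly. A small sketch, assuming TensorFlow 1.x-style ops (tf.range, tf.tile, tf.stack); the variable names are just illustrative:

row_indices = tf.constant([[0, 1], [1, 3]])
batch_size = tf.shape(row_indices)[0]
indices_per_batch = tf.shape(row_indices)[1]

# Pair each row index with its batch index:
# [[[0, 0], [0, 1]], [[1, 1], [1, 3]]]
batch_indices = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, indices_per_batch])
nd_indices = tf.stack([batch_indices, row_indices], axis=-1)

result = tf.gather_nd(logits, nd_indices)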

tf.gather_nd should be available as of v0.10. Check out this GitHub issue for more discussion.
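
On more recent TensorFlow releases (the batch_dims argument appeared around 1.14 and is available in 2.x), the same per-batch row selection can be written even more compactly with tf.gather; a minimal sketch:

indices = tf.constant([[0, 1], [1, 3]])
# Row b of `indices` selects rows within batch element b.
result = tf.gather(logits, indices, batch_dims=1)  # shape [2, 2, 4]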

answered Nov 09 '22 by kafman