Elegant Way to Select One Element per Row in TensorFlow

Given:

  • a matrix A of shape [m, n]
  • a tensor I of shape [m]

I want to get a list J of elements from A where J[i] = A[i, I[i]].

That is, I holds the index of the element to select from each row in A.
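To pin down the intended semantics, here is a minimal plain-Python sketch of the rule J[i] = A[i, I[i]]. The helper name `select_per_row` and the sample values are illustrative assumptions, not part of the question.

```python
# Plain-Python reference for J[i] = A[i, I[i]] (sample values are made up).
A = [[10, 11, 12],
     [20, 21, 22],
     [30, 31, 32],
     [40, 41, 42]]   # shape [m, n] = [4, 3]
I = [2, 0, 1, 2]     # shape [m]: one column index per row


def select_per_row(matrix, indices):
    """Pick matrix[i][indices[i]] for every row i (hypothetical helper)."""
    return [row[col] for row, col in zip(matrix, indices)]


J = select_per_row(A, I)
print(J)  # [12, 20, 31, 42]
```

Any TensorFlow solution should produce the same values as this loop, just without iterating in Python.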

Context: I already have argmax(A, 1) and now I also want the max. I know that I can just use reduce_max, and after experimenting for a bit I also came up with this:

J = tf.gather_nd(A,
    tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))

The to_int64 is needed because range only produces int32 and argmax only produces int64.
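For readers on a current TensorFlow release: tf.pack was renamed to tf.stack in TensorFlow 1.0, and tf.to_int64 has been deprecated in favor of tf.cast. The following is a sketch of the same gather_nd idea, assuming TensorFlow 2.x with eager execution; the sample values are made up for illustration.

```python
import tensorflow as tf

# Sample data (illustrative only).
A = tf.constant([[1.0, 9.0, 3.0],
                 [7.0, 2.0, 5.0]])   # shape [m, n] = [2, 3]
I = tf.argmax(A, axis=1)             # int64 column index per row

# Row indices 0..m-1, cast so both index columns share the int64 dtype.
rows = tf.cast(tf.range(tf.shape(A)[0]), tf.int64)

# Pair each row index with its column index: shape [m, 2].
indices = tf.stack([rows, I], axis=1)

# gather_nd reads A[row, col] for every (row, col) pair.
J = tf.gather_nd(A, indices)
print(J.numpy())  # [9. 7.]
```

Using tf.stack with axis=1 replaces the separate transpose step, and tf.shape keeps the code usable when the number of rows is only known at run time.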

Neither of the two strikes me as particularly elegant. One has runtime overhead (probably about a factor of n) and the other has an unknown amount of cognitive overhead. Am I missing something here?

2 Answers

The gather() function provides a way to do it:

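import tensorflow as tf

r = tf.random.uniform([4, 5], 0, 9, dtype=tf.int32)  # matrix A, shape [m, n] = [4, 5]
i = tf.random.uniform([4], 0, 5, dtype=tf.int32)  # one column index per row; maxval is exclusive
tf.gather(r, i, axis=1, batch_dims=1)  # shape [4]; element k is r[k, i[k]]

This is a rather late answer, but could doing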

mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
elements = tf.boolean_mask(A, mask)

accomplish what you're looking for?

Edit: I should point out that this is NOT a good idea if A is already a very large tensor, because the mask is a dense boolean tensor with the same shape as A.
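A small runnable version of the mask approach, assuming TensorFlow 2.x with eager execution. The sample values are made up, and n is read from the static shape of A.

```python
import tensorflow as tf

# Sample data (illustrative only).
A = tf.constant([[10, 11, 12],
                 [20, 21, 22],
                 [30, 31, 32]])   # shape [m, n] = [3, 3]
I = tf.constant([2, 0, 1])        # one column index per row

n = A.shape[1]  # number of columns, taken from the static shape

# Dense [m, n] boolean mask with exactly one True per row.
mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)

# boolean_mask walks A in row-major order, so the output follows row order.
elements = tf.boolean_mask(A, mask)
print(elements.numpy())  # [12 20 31]
```

One caveat: tf.one_hot produces an all-False row for an index outside [0, n), so the result would then contain fewer than m elements instead of raising an error.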

