Given:
A of shape [m, n]
I of shape [m]
I want to get a list J of elements from A where 
J[i] = A[i, I[i]].
That is, I holds the index of the element to select from each row in A.
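To make the intended semantics concrete, here is a small plain-Python sketch (no TensorFlow) of what J should contain. The helper `select_per_row` and the sample values are hypothetical, for illustration only.

```python
# Reference semantics (hypothetical helper): J[i] = A[i][I[i]] for every row i.
def select_per_row(A, I):
    """Return [A[i][I[i]] for each row i]; A is a list of rows, I a list of column indices."""
    if len(A) != len(I):
        raise ValueError("A and I must have the same number of rows")
    # Pair each row with its own column index and pick that single element.
    return [row[col] for row, col in zip(A, I)]

A = [[3, 1, 4],
     [1, 5, 9],
     [2, 6, 5]]
I = [2, 0, 1]
J = select_per_row(A, I)
print(J)  # [4, 1, 6]
```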
Context: I already have the argmax(A, 1) and now I also want the max.
I know that I can just use reduce_max.
After experimenting for a bit, I also came up with this:
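import tensorflow as tf
# Pair each row index with its column index from I -> shape [m, 2], then gather.
J = tf.gather_nd(A,
    tf.stack([tf.cast(tf.range(tf.shape(A)[0]), tf.int64), I], axis=1))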
The cast to int64 is needed because tf.range yields int32 here, while tf.argmax returns int64 by default.
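A sketch assuming TensorFlow 2.x with eager execution: building the row indices as int64 up front (via `tf.shape(..., out_type=tf.int64)`) lets the gather_nd approach skip the separate cast. The sample matrix is made up for illustration.

```python
import tensorflow as tf

A = tf.constant([[3., 1., 4.],
                 [1., 5., 9.],
                 [2., 6., 5.]])
I = tf.argmax(A, axis=1)  # int64 by default

# Row indices 0..m-1, produced directly as int64 so they match I.
rows = tf.range(tf.shape(A, out_type=tf.int64)[0])
# Pair each row index with its column index: shape [m, 2].
indices = tf.stack([rows, I], axis=1)
J = tf.gather_nd(A, indices)
print(J.numpy())
```

Alternatively, `tf.argmax` accepts `output_type=tf.int32`, which makes the indices match the default int32 `tf.range` instead.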
Neither of the two strikes me as particularly elegant.
The first (reduce_max) has runtime overhead (probably about a factor of n), and the second has an unknown factor of cognitive overhead. Am I missing something here?
The tf.gather() function can do this directly via its batch_dims argument:
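import tensorflow as tf
r = tf.random.uniform([4, 5], 0, 9, dtype=tf.int32)  # params, shape [4, 5]
i = tf.random.uniform([4], 0, 5, dtype=tf.int32)     # one column index per row, in [0, 5)
tf.gather(r, i, axis=1, batch_dims=1)                # result[b] = r[b, i[b]], shape [4]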
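Applied to the original question (again a sketch assuming TF 2.x eager mode, with a made-up A), the argmax indices can be fed straight into `tf.gather` with `batch_dims=1`, and the result agrees with `tf.reduce_max`:

```python
import tensorflow as tf

A = tf.constant([[3., 1., 4.],
                 [1., 5., 9.],
                 [2., 6., 5.]])
I = tf.argmax(A, axis=1)  # column index of the max in each row (int64)

# batch_dims=1 pairs row b of A with I[b], so J[b] = A[b, I[b]].
J = tf.gather(A, I, axis=1, batch_dims=1)

# For the argmax use case this must equal the row-wise max.
J_max = tf.reduce_max(A, axis=1)
print(J.numpy(), J_max.numpy())
```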
This is a rather late answer, but would the following accomplish what you're looking for?
n = A.shape[1]  # number of columns in A
mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
elements = tf.boolean_mask(A, mask)
Edit: I should point out that this is NOT a good idea if A is already very large, because the mask is a dense [m, n] matrix the same size as A.
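A quick way to see the trade-off, assuming TF 2.x eager mode and a made-up A: the mask approach returns the same values as `tf.gather` with `batch_dims=1`, but the mask itself has as many entries as A.

```python
import tensorflow as tf

A = tf.constant([[3., 1., 4.],
                 [1., 5., 9.],
                 [2., 6., 5.]])
I = tf.argmax(A, axis=1)
n = A.shape[1]

# Dense [m, n] boolean mask with exactly one True per row.
mask = tf.one_hot(I, depth=n, on_value=True, off_value=False, dtype=tf.bool)
# boolean_mask keeps the True positions in row-major order, giving shape [m].
J_mask = tf.boolean_mask(A, mask)
# Same selection without materializing the mask.
J_gather = tf.gather(A, I, axis=1, batch_dims=1)
print(J_mask.numpy(), J_gather.numpy())
```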