Given A of shape [m, n] and I of shape [m], I want to get a list J of elements from A where J[i] = A[i, I[i]].
That is, I holds the index of the element to select from each row in A.
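For reference, here is a small sketch of the intended semantics in NumPy rather than TensorFlow (the values are made up for illustration). The explicit loop is the definition, and NumPy integer-array indexing, pairing np.arange(m) with I, computes the same thing:

```python
import numpy as np

# Hypothetical small example: m = 3 rows, n = 4 columns.
A = np.array([[10, 11, 12, 13],
              [20, 21, 22, 23],
              [30, 31, 32, 33]])
I = np.array([2, 0, 3])  # one column index per row

# Reference definition: J[i] = A[i, I[i]], written as an explicit loop.
J_loop = np.array([A[i, I[i]] for i in range(A.shape[0])])

# Same selection via integer-array indexing: row i is paired with column I[i].
J = A[np.arange(A.shape[0]), I]

print(J)  # [12 20 33]
```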
Context: I already have argmax(A, 1) and now I also want the max.
I know that I can just use reduce_max. After experimenting for a bit, I also came up with this:
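import tensorflow as tf
# tf.pack was renamed tf.stack in TensorFlow 1.0; tf.cast replaces the deprecated tf.to_int64.
J = tf.gather_nd(A,
    tf.transpose(tf.stack([tf.cast(tf.range(tf.shape(A)[0]), tf.int64), I])))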
Here the int64 cast is needed because tf.range produces int32 by default, while argmax produces int64 by default.
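A possible simplification, assuming TensorFlow 2.x eager mode, where tf.argmax accepts an output_type argument: requesting int32 indices removes the need for the cast, and tf.stack(..., axis=1) builds the [m, 2] index pairs directly, so no transpose is needed. The matrix below is a made-up example.

```python
import tensorflow as tf

A = tf.constant([[1., 7., 3.],
                 [9., 2., 4.]])

# Ask argmax for int32 so it matches tf.range's default dtype (no cast needed).
I = tf.argmax(A, axis=1, output_type=tf.int32)

# Build [[0, I[0]], [1, I[1]], ...] by stacking along axis=1 instead of transposing.
rows = tf.range(tf.shape(A)[0])
idx = tf.stack([rows, I], axis=1)  # shape [m, 2]

J = tf.gather_nd(A, idx)
print(J.numpy())  # [7. 9.]
```

The underlying technique is the same gather_nd call as above; only the index dtype and the stacking axis change.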
Neither of the two strikes me as particularly elegant.
One has runtime overhead (probably about a factor of n), and the other has an unknown amount of cognitive overhead. Am I missing something here?
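The tf.gather() function, with batch_dims=1, provides a way to do it:
import tensorflow as tf

r = tf.random.uniform([4, 5], 0, 9, dtype=tf.int32)  # plays the role of A, shape [m, n] = [4, 5]
i = tf.random.uniform([4], 0, 5, dtype=tf.int32)     # plays the role of I, one column index in [0, n) per row
tf.gather(r, i, axis=1, batch_dims=1)                # result[k] = r[k, i[k]], shape [4]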
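To connect this with the original context, here is a sketch (assuming TensorFlow 2.x eager mode and made-up values) that feeds the int64 argmax indices straight into tf.gather with batch_dims=1. The result should match tf.reduce_max along the same axis:

```python
import tensorflow as tf

A = tf.constant([[1., 7., 3.],
                 [9., 2., 4.],
                 [0., 5., 6.]])
I = tf.argmax(A, axis=1)  # int64 indices, shape [m]

# batch_dims=1 treats dimension 0 as a batch: for each row k, gather A[k, I[k]].
J = tf.gather(A, I, axis=1, batch_dims=1)

print(J.numpy())                         # [7. 9. 6.]
print(tf.reduce_max(A, axis=1).numpy())  # [7. 9. 6.]
```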
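This is a rather late answer, but would the following accomplish what you're looking for?
import tensorflow as tf
mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)  # one True per row, at column I[i]
# boolean_mask returns the selected entries in row-major order, so elements[i] == A[i, I[i]]
elements = tf.boolean_mask(A, mask)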
Edit: I should point out that this is NOT a good idea if A is already a very large tensor, as it builds a dense [m, n] boolean mask.
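A small worked example of the boolean-mask approach, assuming TensorFlow 2.x eager mode and made-up values. It also shows that the mask has the full [m, n] shape, which is where the memory cost mentioned above comes from:

```python
import tensorflow as tf

A = tf.constant([[1., 7., 3.],
                 [9., 2., 4.]])
I = tf.constant([2, 0])
n = A.shape[1]  # static column count (3 here)

# One True per row, at column I[k].
mask = tf.one_hot(I, depth=n, on_value=True, off_value=False, dtype=tf.bool)

# boolean_mask flattens the selected entries in row-major order, so with exactly
# one True per row the result has shape [m] and J[k] == A[k, I[k]].
J = tf.boolean_mask(A, mask)

print(J.numpy())   # [3. 9.]
print(mask.shape)  # (2, 3)
```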