I have a 2D tensor A with shape [batch_size, D], and a 1D tensor B with shape [batch_size]. Each element of B is a column index into the corresponding row of A, i.e. B[i] is in [0, D).
What is the best way in TensorFlow to get the values A[i, B[i]] for every row i?
For example:
A = tf.constant([[0, 1, 2],
                 [3, 4, 5]])
B = tf.constant([2, 1])
with desired output:
some_slice_func(A, B) -> [2,4]
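To pin down the intended semantics, here is a minimal reference sketch in NumPy (NumPy is an assumption here, used only because it has built-in advanced indexing; vector_slice_reference is a hypothetical name). Pairing each row index i with B[i] picks exactly one element per row:

```python
import numpy as np


def vector_slice_reference(A, B):
    """NumPy stand-in (not TensorFlow): out[i] = A[i, B[i]] for every row i."""
    A = np.asarray(A)
    B = np.asarray(B)
    # Advanced indexing pairs row index rows[k] with column index B[k].
    rows = np.arange(A.shape[0])
    return A[rows, B]


A = np.array([[0, 1, 2],
              [3, 4, 5]])
B = np.array([2, 1])
print(vector_slice_reference(A, B))  # [2 4]
```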
There is another constraint: in practice, batch_size is actually None (the batch dimension is only known at run time).
Thanks in advance!
I was able to get it working using a linear index:
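import tensorflow as tf

def vector_slice(A, B):
    """ Returns values of rows i of A at column B[i]
    where A is a 2D Tensor with shape [None, D]
    and B is a 1D Tensor with shape [None]
    with type int32 elements in [0,D)
    Example:
      A = [[1, 2],   B = [0, 1],   vector_slice(A, B) -> [1, 4]
           [3, 4]]
    """
    # Flat offset where each row starts: D * i for i in [0, batch_size).
    # tf.shape gives run-time sizes, so this works when batch_size is None.
    linear_index = (tf.shape(A)[1]
                    * tf.range(0, tf.shape(A)[0]))
    # Flatten A row-major, so element (i, j) sits at position i * D + j.
    linear_A = tf.reshape(A, [-1])
    return tf.gather(linear_A, B + linear_index)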
This feels slightly hacky, though.
If anyone knows a better way (clearer or faster), please also leave an answer! (I won't accept my own for a while.)
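The linear-index answer relies on tf.reshape(A, [-1]) flattening A in row-major order, so that element (i, j) lands at flat position i * D + j. A small NumPy sketch (NumPy is an assumption, used only to make the arithmetic easy to check; linear_indices is a hypothetical helper) reproduces the index computation for the docstring example:

```python
import numpy as np


def linear_indices(num_cols, B):
    """Flat positions of A[i, B[i]] in a row-major [len(B), num_cols] array.

    NumPy stand-in for checking the arithmetic; not TensorFlow code.
    """
    B = np.asarray(B)
    # Row i starts at flat offset num_cols * i; add the column index B[i].
    return num_cols * np.arange(B.shape[0]) + B


A = np.array([[1, 2],
              [3, 4]])
B = [0, 1]
idx = linear_indices(A.shape[1], B)
print(idx.tolist())                 # [0, 3]
print(A.reshape(-1)[idx].tolist())  # [1, 4]
```

With D = 2 the row offsets are [0, 2]; adding B = [0, 1] gives [0, 3], which picks 1 and 4 out of the flattened [1, 2, 3, 4].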
Code for what @Eugene Brevdo said:
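import tensorflow as tf

def vector_slice(A, B):
    """ Returns values of rows i of A at column B[i]
    where A is a 2D Tensor with shape [None, D]
    and B is a 1D Tensor with shape [None]
    with type int32 elements in [0,D)
    Example:
      A = [[1, 2],   B = [0, 1],   vector_slice(A, B) -> [1, 4]
           [3, 4]]
    """
    # Column indices as a [batch_size, 1] column.
    B = tf.expand_dims(B, 1)
    # Matching [batch_size, 1] column of row indices 0..batch_size-1
    # (renamed from `range` to avoid shadowing the Python builtin).
    row_indices = tf.expand_dims(tf.range(tf.shape(B)[0]), 1)
    # Each row of ind is a (row, column) pair for tf.gather_nd.
    ind = tf.concat([row_indices, B], 1)
    return tf.gather_nd(A, ind)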
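The index matrix that the tf.gather_nd version builds can be mirrored in NumPy to see the (row, column) pairs it produces. This is only an inspection sketch: NumPy stands in for TensorFlow, and gather_nd_indices is a hypothetical helper.

```python
import numpy as np


def gather_nd_indices(B):
    """Build the [batch_size, 2] index matrix: row k is (k, B[k]).

    NumPy stand-in for illustration; not TensorFlow code.
    """
    B = np.asarray(B)
    rows = np.arange(B.shape[0])
    # Same result as concatenating the two [batch_size, 1] columns on axis 1.
    return np.stack([rows, B], axis=1)


A = np.array([[0, 1, 2],
              [3, 4, 5]])
ind = gather_nd_indices([2, 1])
print(ind.tolist())                      # [[0, 2], [1, 1]]
print(A[ind[:, 0], ind[:, 1]].tolist())  # [2, 4]
```

Each row of ind names one element of A, so gather_nd returns one value per row. Newer TensorFlow releases also accept a batch_dims argument on tf.gather, and tf.gather(A, B, batch_dims=1) should give the same [batch_size] result without building ind; check that your TensorFlow version supports it.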