 

Tensorflow indexing into 2d tensor with 1d tensor

Tags:

tensorflow

I have a 2D tensor A with shape [batch_size, D] and a 1D tensor B with shape [batch_size]. Each element of B is a column index into the corresponding row of A, i.e. B[i] is in [0, D).

What is the best way in TensorFlow to get the values A[B], i.e. A[i, B[i]] for every row i?

For example:

A = tf.constant([[0,1,2],
                 [3,4,5]])
B = tf.constant([2,1])

with desired output:

some_slice_func(A, B) -> [2,4]
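In NumPy terms, the requested operation is integer-array ("fancy") indexing that pairs each row index i with the column index B[i]. The sketch below uses NumPy purely as a reference for the expected result; it is not a TensorFlow solution.

```python
import numpy as np

A = np.array([[0, 1, 2],
              [3, 4, 5]])
B = np.array([2, 1])

# Pair row i with column B[i]: this picks A[0, 2] and A[1, 1]
rows = np.arange(A.shape[0])
result = A[rows, B]
print(result)  # [2 4]
```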

There is one more constraint: in practice, batch_size is None (it is not known until run time).

Thanks in advance!

Taaam asked Sep 13 '25 10:09



2 Answers

I was able to get it working by flattening A and using a linear index:

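import tensorflow as tf

def vector_slice(A, B):
    """ Returns values of rows i of A at column B[i]

    where A is a 2D Tensor with shape [None, D]
    and B is a 1D Tensor with shape [None]
    with type int32 elements in [0,D)

    Example:
      A = [[1, 2],
           [3, 4]]
      B = [0, 1]
      vector_slice(A, B) -> [1, 4]
    """
    # Start offset of each row in the flattened tensor: i * D
    linear_index = (tf.shape(A)[1]
                    * tf.range(0, tf.shape(A)[0]))
    # Flatten A row by row, so element (i, j) sits at position i * D + j
    linear_A = tf.reshape(A, [-1])
    # Gather the flat positions i * D + B[i]
    return tf.gather(linear_A, B + linear_index)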

This feels slightly hacky though.

If anyone knows a better way (as in clearer or faster), please leave an answer too! (I won't accept my own for a while.)
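The trick relies on row-major flattening: element (i, j) of an [N, D] tensor lands at position i * D + j of the flattened tensor, so adding i * D to B[i] gives the flat position of A[i, B[i]]. The NumPy sketch below reproduces just that arithmetic on the docstring's example; it mirrors the TensorFlow code step by step but is only an illustration.

```python
import numpy as np

A = np.array([[1, 2],
              [3, 4]])
B = np.array([0, 1])

n_rows, d = A.shape
# Start offset of each row in the flattened array: [0, d, 2*d, ...]
row_offsets = d * np.arange(n_rows)
linear_index = row_offsets + B   # flat positions of A[i, B[i]]
flat_A = A.reshape(-1)           # row-major flatten, like tf.reshape(A, [-1])
result = flat_A[linear_index]
print(linear_index, result)      # [0 3] [1 4]
```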

Taaam answered Sep 16 '25 07:09



Code for what @Eugene Brevdo suggested, building (row, column) index pairs for tf.gather_nd:

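import tensorflow as tf

def vector_slice(A, B):
    """ Returns values of rows i of A at column B[i]

    where A is a 2D Tensor with shape [None, D]
    and B is a 1D Tensor with shape [None]
    with type int32 elements in [0,D)

    Example:
      A = [[1, 2],
           [3, 4]]
      B = [0, 1]
      vector_slice(A, B) -> [1, 4]
    """
    # Column indices as a [batch_size, 1] tensor
    B = tf.expand_dims(B, 1)
    # Row indices 0..batch_size-1 as a [batch_size, 1] int32 tensor
    # (named row_ids to avoid shadowing Python's built-in range)
    row_ids = tf.expand_dims(tf.range(tf.shape(B)[0]), 1)
    # Pair them into [batch_size, 2] (row, column) coordinates;
    # B must be int32 to be concatenated with row_ids
    ind = tf.concat([row_ids, B], 1)
    # Pick A[row_ids[i], B[i]] for every i
    return tf.gather_nd(A, ind)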
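An alternative, assuming TensorFlow 2.x: tf.gather accepts a batch_dims argument, and with batch_dims=1 it computes output[i] = A[i, B[i]] directly, without building index pairs by hand. The helper name gather_per_row below is hypothetical. The input_signature leaves both dimensions of A and the length of B as None, to check that an unknown batch size works.

```python
import tensorflow as tf

# Hypothetical helper; assumes TensorFlow 2.x.
# Shapes are left as None to mirror the question's unknown batch_size.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),
    tf.TensorSpec(shape=[None], dtype=tf.int32),
])
def gather_per_row(A, B):
    # With batch_dims=1 (axis defaults to batch_dims): output[i] = A[i, B[i]]
    return tf.gather(A, B, batch_dims=1)

A = tf.constant([[0, 1, 2],
                 [3, 4, 5]])
B = tf.constant([2, 1])
print(gather_per_row(A, B).numpy())  # [2 4]
```

Because the traced function only depends on the rank of A and B, the same graph is reused for any batch size.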

hrushikesh answered Sep 16 '25 07:09



