
TensorFlow: How to index a tensor using a 2D index like in NumPy

I would like to reproduce the following NumPy code in TensorFlow:

import numpy as np

input = np.array([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
index1 = [0,1,2]
index2 = [2,2,0]
output = input[index1, index2]
>> output
[3,6,7]

given an input such as:

input = tf.constant([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])

I have tried the following, but it seems like overkill:

index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2
output = tf.gather(tf.reshape(input, [-1]), index3)
sess = tf.Session()
sess.run(output)
>> [3,6,7]

This works only because my first index happens to be [0,1,2]; as written, it would not work for [0,0,2], for example (besides being long and ugly).
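
For reference, the flattened-index approach can be made to handle arbitrary row indices by computing each offset as row * number_of_columns + column instead of relying on tf.range. A minimal sketch, assuming TensorFlow 2.x with eager execution (the snippet above uses the 1.x Session API), with index1 set to the [0,0,2] case and inp, n_cols and flat_index being names chosen here for illustration:

```python
import tensorflow as tf

inp = tf.constant([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
index1 = tf.constant([0, 0, 2])  # row indices (the case the tf.range version cannot handle)
index2 = tf.constant([2, 2, 0])  # column indices

# Row-major offset of element (r, c) in the flattened tensor is r * n_cols + c.
n_cols = tf.shape(inp)[1]
flat_index = index1 * n_cols + index2

# Flatten the matrix and pick the elements at the computed offsets.
output = tf.gather(tf.reshape(inp, [-1]), flat_index)
print(output.numpy())  # [3 3 7]
```

This still requires the reshape and the offset arithmetic, so it is arguably no less verbose than the original attempt.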

Is there a simpler syntax, something more tensor-like/Pythonic?

ylnor asked Mar 31 '17 16:03




1 Answer

You can do it using tf.gather_nd (see the official tf.gather_nd documentation) as follows:

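import tensorflow as tf

inp = tf.constant([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
# Pair row and column indices into coordinates: [(0, 2), (1, 2), (2, 0)]
res = tf.gather_nd(inp, list(zip([0, 1, 2], [2, 2, 0])))
sess = tf.Session()  # TensorFlow 1.x graph-mode session
sess.run(res)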

The result is array([3, 6, 7])
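
The same lookup can also be written with tensor-valued indices, which may help when the indices are not available as Python lists. A minimal sketch, assuming TensorFlow 2.x with eager execution (so no Session is needed); index1 and index2 follow the question, while indices is a name introduced here for illustration:

```python
import tensorflow as tf

inp = tf.constant([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
index1 = tf.constant([0, 1, 2])  # row indices
index2 = tf.constant([2, 2, 0])  # column indices

# Stack along axis 1 so each row of `indices` is one (row, column) coordinate.
indices = tf.stack([index1, index2], axis=1)  # shape (3, 2)

res = tf.gather_nd(inp, indices)
print(res.numpy())  # [3 6 7]
```

Here tf.stack plays the role of zip in the snippet above: it builds the coordinate pairs that tf.gather_nd expects in its last dimension.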

Miriam Farber answered Oct 22 '22 10:10
