TensorFlow getting elements of every row for specific columns

Suppose A is a TensorFlow variable:

A = tf.Variable([[1, 2], [3, 4]])

and index is another variable

index = tf.Variable([0, 1])

I want to use index to select one column in each row. In this case, that means item 0 from the first row and item 1 from the second row.

If A were a NumPy array, we could pick the column named by index from each row with integer array (fancy) indexing:

import numpy as np
x = A[np.arange(A.shape[0]), index]

and the result would be

[1, 4]
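To make the row/column pairing explicit, here is a small NumPy-only sketch (assuming A and index are plain NumPy arrays) that computes the same selection three ways. The loop shows which elements fancy indexing picks, and np.take_along_axis is an extra equivalent not mentioned in the question.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
index = np.array([0, 1])

# Fancy indexing pairs row i with column index[i]
rows = np.arange(A.shape[0])
x = A[rows, index]

# Explicit loop over the same (row, column) pairs
x_loop = np.array([A[i, index[i]] for i in range(A.shape[0])])

# take_along_axis: one column per row, then drop the singleton axis
x_take = np.take_along_axis(A, index[:, None], axis=1)[:, 0]

print(x, x_loop, x_take)  # [1 4] [1 4] [1 4]
```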

What is the equivalent operation (or sequence of operations) in TensorFlow? I know TensorFlow does not support many NumPy-style indexing operations. What would be the workaround if it cannot be done directly?

Kashyap asked Sep 25 '16 07:09



1 Answer

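You can use tf.one_hot to turn index into a one-hot array and then pass it to tf.boolean_mask as a boolean mask to select the elements you want. Each row of the mask is True at exactly one column, and boolean_mask returns the selected elements in row-major order, so the output holds one element per row.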

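import tensorflow as tf

A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([0, 1])

# Row i of the mask is True only at column index[i]: [[True, False], [False, True]]
one_hot_mask = tf.one_hot(index, A.shape[1], on_value=True, off_value=False, dtype=tf.bool)
# Keep only the masked elements, in row-major order -> [1, 4]
output = tf.boolean_mask(A, one_hot_mask)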
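As an alternative to the mask (not part of the answer above), here is a sketch assuming TensorFlow 2.x with eager execution. tf.gather_nd with explicit (row, column) pairs mirrors the NumPy expression most literally, and tf.gather with batch_dims=1 performs the same per-row lookup in one call.

```python
import tensorflow as tf

A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([0, 1])

# Row numbers 0..num_rows-1, same integer dtype as index
rows = tf.range(tf.shape(A)[0], dtype=index.dtype)

# Stack into (row, column) pairs: [[0, 0], [1, 1]]
pairs = tf.stack([rows, index], axis=1)

# gather_nd reads A[row, col] for each pair -> shape (num_rows,)
via_gather_nd = tf.gather_nd(A, pairs)

# batch_dims=1: for each row i, gather column index[i] from that row
via_batch_gather = tf.gather(A, index, batch_dims=1)

print(via_gather_nd.numpy(), via_batch_gather.numpy())
```

Unlike the one-hot approach, neither version builds a num_rows × num_columns mask, which may matter when A has many columns.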
Mehmet answered Sep 21 '22 01:09
