If A is a TensorFlow variable like so:
A = tf.Variable([[1, 2], [3, 4]])
and index is another variable:
index = tf.Variable([0, 1])
I want to use this index to select one column in each row: in this case, item 0 from the first row and item 1 from the second row.
If A were a NumPy array, we could get the column of each row given in index with
x = A[np.arange(A.shape[0]), index]
and the result would be
[1, 4]
What is the equivalent TensorFlow operation (or sequence of operations) for this? I know TensorFlow doesn't support many indexing operations. What would be the workaround if it cannot be done directly?
You can use tf.one_hot to build a one-hot boolean mask from index, then pass that mask to tf.boolean_mask to select the entries you'd like.
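import tensorflow as tf
A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([0, 1])
# True at column index[i] of row i, False everywhere else
one_hot_mask = tf.one_hot(index, A.shape[1], on_value=True, off_value=False, dtype=tf.bool)
# Keeps only the masked entries, in row-major order: [1, 4]
output = tf.boolean_mask(A, one_hot_mask)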
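If you'd rather stay closer to the NumPy expression, another option is tf.gather_nd, which takes explicit (row, column) pairs. The sketch below assumes TensorFlow 2.x with eager execution and integer column indices. The helper name select_per_row is made up for illustration and is not part of TensorFlow.

```python
import tensorflow as tf


def select_per_row(A, index):
    """Pick A[i, index[i]] for every row i (hypothetical helper name)."""
    # Cast so the column indices match the int32 row numbers from tf.range.
    cols = tf.cast(index, tf.int32)
    rows = tf.range(tf.shape(A)[0])
    # Pair each row number with its column index: [[0, c0], [1, c1], ...]
    pairs = tf.stack([rows, cols], axis=1)
    return tf.gather_nd(A, pairs)


A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([0, 1])
output = select_per_row(A, index)  # same entries as the NumPy indexing above
```

Unlike the mask approach, this does not build a full boolean matrix the size of A, which may matter when A has many columns.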