How do I select certain columns of a 2D tensor in TensorFlow?



Since generalized slicing is still being worked on in this issue, what is the best way to gather columns of a 2D tensor (matrix)? For example, for tensor t:

1 2 3 4
5 6 7 8 

and indices [1,3], I would like to get:

2 4
6 8

which is equivalent to numpy t[:, [1,3]].
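For reference, the target NumPy behaviour can be checked directly. This small sketch assumes only NumPy and uses the matrix and indices from the question:

```python
import numpy as np

# The matrix t and the column indices from the question.
t = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])
cols = [1, 3]

# Integer-array indexing on the second axis keeps every row
# and picks columns 1 and 3, in the order given.
picked = t[:, cols]
print(picked)
```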

2 Answers

In the meantime, tf.gather has gained an axis parameter, so columns can be gathered directly:

import tensorflow as tf

params = tf.constant([[1,2,3],[4,5,6]])
indices = [0,2]  # columns to keep
op = tf.gather(params, indices, axis=1)  # gather along axis 1 (columns)

produces the output

[[1 3]
 [4 6]]
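Applied to the matrix and indices from the question, the same call should reproduce the NumPy result of t[:, [1,3]]. This is a sketch assuming TensorFlow 2.x with eager execution (the default), so that .numpy() can be called on the result:

```python
import numpy as np
import tensorflow as tf

t = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
indices = tf.constant([1, 3])

# Gather along axis 1, i.e. pick whole columns.
cols = tf.gather(t, indices, axis=1)

# Cross-check against NumPy fancy indexing on the same data.
expected = t.numpy()[:, [1, 3]]
print(cols.numpy())
```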

There is a function named tf.nn.embedding_lookup(params, ind), which retrieves rows of the params tensor.
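To see that embedding_lookup works row-wise, here is a minimal sketch assuming TensorFlow 2.x eager execution; the 3x3 params matrix and the ids are made-up illustration values:

```python
import tensorflow as tf

# Illustrative values only.
params = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0],
                      [7.0, 8.0, 9.0]])
ids = tf.constant([2, 0])

# embedding_lookup returns whole rows of params: row 2, then row 0.
rows = tf.nn.embedding_lookup(params, ids)

# For a single params tensor this matches gathering along axis 0.
same = tf.gather(params, ids, axis=0)
print(rows.numpy())
```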

To select columns, first transpose the tensor t whose columns you want. Then look up the rows of tf.transpose(t), which are the columns of t. After the selection, transpose the result back.

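import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
ind = tf.constant([0, 2])  # columns to keep

# Transpose so columns become rows, look those rows up, then transpose back.
result = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

with tf.Session() as sess:
    print(sess.run(result))  # [[1 3]
                             #  [4 6]]
# In TensorFlow 2.x, tf.Session is only available as tf.compat.v1.Session;
# with eager execution you can simply call print(result.numpy()).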

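Under TensorFlow 2.x, where eager execution replaces the session, the same transpose, lookup, transpose idea needs no tf.Session. A minimal sketch on the question's 2x4 matrix, assuming TF 2.x; float32 values are used only for illustration:

```python
import tensorflow as tf

t = tf.constant([[1.0, 2.0, 3.0, 4.0],
                 [5.0, 6.0, 7.0, 8.0]])
ind = tf.constant([1, 3])

# Columns of t are the rows of tf.transpose(t); look up rows 1 and 3.
picked_rows = tf.nn.embedding_lookup(tf.transpose(t), ind)
# Transposing back restores the original orientation.
result = tf.transpose(picked_rows)

# tf.gather with axis=1 selects the same columns in a single op.
direct = tf.gather(t, ind, axis=1)
print(result.numpy())
```

The two transposes are the extra work that gathering along axis=1 avoids.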