 

How do I select certain columns of a 2D tensor in TensorFlow?

Tags:

tensorflow

Since generalized slicing is still being worked on in this issue, what would be the best way to implement an op that gathers columns of a 2D tensor (matrix)? For example, for the tensor t:

1 2 3 4
5 6 7 8 

and indices [1,3], I would like to get:

2 4
6 8

which is equivalent to t[:, [1,3]] in NumPy.
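For reference, the NumPy behaviour the question asks for can be reproduced as follows. This is a minimal sketch; np.take with axis=1 is shown only as the NumPy function-call counterpart of "gather along an axis", not as something the question itself mentions.

```python
import numpy as np

# The 2x4 matrix t from the question.
t = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])

# Fancy indexing on the second axis keeps every row and picks columns 1 and 3.
cols = t[:, [1, 3]]
print(cols)
# [[2 4]
#  [6 8]]

# np.take with axis=1 expresses the same column selection as a function call.
same = np.take(t, [1, 3], axis=1)
print(np.array_equal(cols, same))  # True
```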

asked Jun 07 '16 05:06 by Andrzej Pronobis


2 Answers

In the meantime, tf.gather has gained an axis parameter, so columns can be gathered directly:

import tensorflow as tf
params = tf.constant([[1,2,3],[4,5,6]])
indices = [0,2]
op = tf.gather(params, indices, axis=1)

produces the output

[[1 3]
 [4 6]]
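Applied to the exact matrix and indices from the question, the same call might look like this. It is a sketch assuming TensorFlow 2.x with eager execution (so .numpy() is available on the result); under TensorFlow 1.x graph mode you would evaluate the tensor in a session instead.

```python
import numpy as np
import tensorflow as tf

# The 2x4 matrix and column indices from the question.
t = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
indices = [1, 3]

# axis=1 gathers along the column dimension, keeping every row.
cols = tf.gather(t, indices, axis=1)

# With eager execution the result can be inspected directly.
print(cols.numpy())
# [[2 4]
#  [6 8]]

# Cross-check against the NumPy expression t[:, [1,3]] from the question.
print(np.array_equal(cols.numpy(), t.numpy()[:, indices]))  # True
```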
answered Sep 27 '22 18:09 by AlexConfused



There is a function named tf.nn.embedding_lookup(params, ind) that retrieves the rows of the params tensor at the indices in ind.
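To make "retrieves the rows" concrete, here is a small sketch assuming TensorFlow 2.x eager execution. The 3x3 matrix and the ids [2, 0] are illustrative values chosen for this example only.

```python
import tensorflow as tf

# Illustrative 3x3 matrix (not from the question).
params = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# Each id selects one row of params; the result stacks those rows in id order.
rows = tf.nn.embedding_lookup(params, [2, 0])
print(rows.numpy())
# [[7 8 9]
#  [1 2 3]]

# For a single dense tensor this matches tf.gather along axis 0.
print((rows.numpy() == tf.gather(params, [2, 0]).numpy()).all())  # True
```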

To achieve what you want, first transpose the tensor t whose columns you want to select, then look up the rows of tf.transpose(t) (that is, the columns of t), and finally transpose the result back.

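import tensorflow as tf

# Runs under TensorFlow 2.x eager execution. In TensorFlow 1.x, evaluate
# result inside a session instead: with tf.Session() as sess: sess.run(result)
t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
ind = tf.constant([0, 2])

# Transpose so the wanted columns become rows, look those rows up,
# then transpose back so they are columns again.
result = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

print(result.numpy())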
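As a sanity check, the sketch below applies the transpose trick to the question's own 2x4 matrix and compares it with a direct tf.gather along axis 1. It assumes TensorFlow 2.x eager execution and a TensorFlow version whose tf.gather accepts axis.

```python
import tensorflow as tf

# The matrix and column indices from the question.
t = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
ind = tf.constant([1, 3])

# Transpose trick: columns of t become rows, look them up, transpose back.
via_lookup = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

# Direct gather along the column axis.
via_gather = tf.gather(t, ind, axis=1)

print(via_lookup.numpy().tolist())  # [[2, 4], [6, 8]]
print(via_lookup.numpy().tolist() == via_gather.numpy().tolist())  # True
```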
answered Sep 27 '22 18:09 by lucky6qi
