Tensorflow python : Accessing individual elements in a tensor

This question is about accessing individual elements in a tensor, say [[1,2,3]]. I need to access the inner element [1,2,3]. This can be done with .eval() or sess.run(), but it takes longer when the tensor is huge.

Is there any method to do the same faster?

Thanks in advance.

asked Feb 02 '16 06:02 by cipher42




2 Answers

There are two main ways to access subsets of the elements in a tensor, either of which should work for your example.

  1. Use the indexing operator (based on tf.slice()) to extract a contiguous slice from the tensor.

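    import tensorflow as tf  # TensorFlow 1.x API (graph mode with tf.Session)
    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    output = input[0, :]  # row 0, all columns
    with tf.Session() as sess:
        print(sess.run(output))  # ==> [1 2 3]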

    The indexing operator supports many of the same slice specifications as NumPy does.

  2. Use the tf.gather() op to select a non-contiguous slice from the tensor.

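    import tensorflow as tf  # TensorFlow 1.x API (graph mode with tf.Session)
    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    with tf.Session() as sess:
        output = tf.gather(input, 0)  # row 0
        print(sess.run(output))  # ==> [1 2 3]
        output = tf.gather(input, [0, 2])  # rows 0 and 2
        print(sess.run(output))  # ==> [[1 2 3] [7 8 9]]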

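    Note that, in its basic form, tf.gather() selects whole slices in the 0th dimension (whole rows in the example of a matrix), so you may need to tf.reshape() or tf.transpose() your input to obtain the appropriate elements. Newer TensorFlow releases also accept an axis argument to tf.gather(), and tf.gather_nd() can select individual elements by their full index.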
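A minimal sketch of both approaches, assuming TensorFlow 2.x with eager execution (so no tf.Session is needed and .numpy() can be called directly on results). Besides the NumPy-style slice specs and an explicit tf.slice() call, it shows the axis argument of tf.gather() and tf.gather_nd(), which TensorFlow 2.x provides for picking columns or individual elements without reshaping or transposing first. The variable names are illustrative only.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution is on by default)

t = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 1. Contiguous slices with the indexing operator (NumPy-style slice specs).
row0 = t[0, :]         # first row:              [1 2 3]
last_row = t[-1]       # negative index:         [7 8 9]
col1 = t[:, 1]         # second column:          [2 5 8]
corners = t[::2, ::2]  # every other row/column: [[1 3] [7 9]]

# The first row again, via an explicit tf.slice(begin, size) call.
# Unlike t[0, :], this keeps the leading dimension: shape (1, 3).
row0_slice = tf.slice(t, begin=[0, 0], size=[1, 3])  # [[1 2 3]]

# 2. Non-contiguous selections with tf.gather / tf.gather_nd.
rows_0_2 = tf.gather(t, [0, 2])            # rows 0 and 2:    [[1 2 3] [7 8 9]]
cols_0_2 = tf.gather(t, [0, 2], axis=1)    # columns 0 and 2: [[1 3] [4 6] [7 9]]
elems = tf.gather_nd(t, [[0, 1], [2, 2]])  # elements (0, 1) and (2, 2): [2 9]

for name, value in [("row0", row0), ("last_row", last_row), ("col1", col1),
                    ("corners", corners), ("row0_slice", row0_slice),
                    ("rows_0_2", rows_0_2), ("cols_0_2", cols_0_2),
                    ("elems", elems)]:
    print(name, value.numpy().tolist())
```

Indexing and tf.slice() return contiguous (or regularly strided) blocks, while tf.gather()/tf.gather_nd() take explicit index lists, which is what makes arbitrary, non-contiguous selections possible.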

answered Oct 01 '22 18:10 by mrry



I hope I understood your question correctly. In TensorFlow 2 (eager execution), you can access elements of a tensor by converting it to a NumPy array with .numpy().

    import tensorflow as tf

    t = tf.constant([[1, 2, 3]])
    print(t.numpy()[0][1])  # This will print 2
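Since the question is about speed on large tensors, note that t.numpy()[0][1] converts the entire tensor to a NumPy array before indexing into it. A minimal sketch, assuming TensorFlow 2.x eager execution: slicing inside TensorFlow first means only the selected element is converted. How much this helps depends on the tensor's size and device; it likely matters most for large tensors or tensors placed on a GPU. In TensorFlow 1.x graph mode the analogous idea is to run only the slice, e.g. sess.run(t[0, 1]), instead of fetching the whole tensor.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

t = tf.constant([[1, 2, 3]])

# Converts the whole tensor to a NumPy array, then indexes it in Python.
whole_then_index = t.numpy()[0][1]

# Indexes inside TensorFlow first; only the 0-d result is converted.
index_then_convert = t[0, 1].numpy()

# .item() turns the NumPy scalar into a plain Python int.
as_python_int = t[0, 1].numpy().item()

print(whole_then_index, index_then_convert, as_python_int)  # 2 2 2
```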
answered Oct 01 '22 17:10 by Peyman
