Slicing Tensorflow Tensor with Tensor

Tags:

tensorflow

I am trying to use the "advanced", NumPy-style slicing that was added in this PR; however, I am running into the same issue as the user here:

ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_15' (op: 'StridedSlice') with input shapes: [3,2], [1,2], [1,2], [1].

Namely, I would like the TensorFlow equivalent of this NumPy operation, which works as expected in NumPy:

import numpy as np

A = np.array([[1,2],[3,4],[5,6]])
id_rows = np.array([0,2])
A[id_rows]  # selects rows 0 and 2 -> [[1, 2], [5, 6]]
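For reference, the NumPy line selects whole rows 0 and 2 of A; np.take(A, id_rows, axis=0) is a more explicit spelling of the same row gather. A minimal check, assuming only NumPy is installed:

```python
import numpy as np

A = np.array([[1, 2], [3, 4], [5, 6]])
id_rows = np.array([0, 2])

# Integer-array ("advanced") indexing on axis 0 picks whole rows 0 and 2.
rows = A[id_rows]

# np.take along axis 0 performs the same row selection explicitly.
rows_take = np.take(A, id_rows, axis=0)

print(rows.tolist())                    # [[1, 2], [5, 6]]
print(np.array_equal(rows, rows_take))  # True
```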

However, the same thing fails in TF with the error above:

import tensorflow as tf

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([0,2])
A[id_rows]  # fails with the error above
Alex Rothberg asked Oct 30 '22 05:10



1 Answer

You are looking for tf.gather_nd, something like this:

import tensorflow as tf

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([[0],[2]])  # Notice the brackets: each entry is a one-element row coordinate
out = tf.gather_nd(A, id_rows)    # [[1, 2], [5, 6]]
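Why the extra brackets matter: tf.gather_nd reads the last axis of the index tensor as a coordinate into A, so [[0], [2]] means "row 0" and "row 2", whereas a flat [0, 2] would be read as the single coordinate (row 0, column 2), which is out of range for this 3x2 A. The sketch below is a hypothetical NumPy re-implementation of that basic behavior (gather_nd_np is not a TF or NumPy function, and it ignores gather_nd's batch_dims argument), meant only to illustrate how the index shape is interpreted:

```python
import numpy as np


def gather_nd_np(params, indices):
    """Hypothetical NumPy sketch of tf.gather_nd's basic semantics (no batch_dims).

    The last axis of `indices` holds a coordinate into the leading axes of
    `params`; all other axes of `indices` are kept in the output shape.
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    depth = indices.shape[-1]            # number of leading params axes each coordinate covers
    coords = indices.reshape(-1, depth)  # flatten to a list of coordinates
    # A tuple index selects params[c0, c1, ...] (a whole row, or a single element).
    picked = [params[tuple(c)] for c in coords]
    out_shape = indices.shape[:-1] + params.shape[depth:]
    return np.array(picked).reshape(out_shape)


A = np.array([[1, 2], [3, 4], [5, 6]])

print(gather_nd_np(A, [[0], [2]]).tolist())  # [[1, 2], [5, 6]]  (rows 0 and 2)
print(gather_nd_np(A, [[0, 1]]).tolist())    # [2]  (single element A[0, 1])
```

Since only whole rows are needed here, tf.gather(A, tf.constant([0, 2])), which gathers along axis 0 by default, selects the same rows without adding the extra brackets.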
vijay m answered Jan 02 '23 20:01
