 

PyTorch tensor indexing

I am currently converting some code from TensorFlow to PyTorch and ran into a problem with tf.gather: there seems to be no direct equivalent function in PyTorch.

What I am trying to do is basically indexing. I have two tensors: a feature tensor of shape [minibatch, 60, 2] and an index tensor of shape [minibatch, 8]. Call the first one A and the second one B.

In TensorFlow, this is done directly with tf.gather(A, B, batch_dims=1).
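For reference, a minimal dependency-free sketch of what tf.gather(A, B, batch_dims=1) computes for 3-D A and 2-D B: for every sample n and every position i, it picks row B[n][i] of A[n], so out[n][i] == A[n][B[n][i]]. The helper name gather_batch_dims_1 and the tiny nested-list tensors below are hypothetical, chosen only to make the semantics visible.

```python
def gather_batch_dims_1(params, indices):
    """Reference semantics of tf.gather(params, indices, batch_dims=1) on nested lists.

    params:  [batch][rows][features]
    indices: [batch][k]
    returns: [batch][k][features], where out[n][i] == params[n][indices[n][i]]
    """
    return [
        [params[n][idx] for idx in indices[n]]  # gather rows of sample n only
        for n in range(len(params))
    ]


# Hypothetical tiny example: batch of 2, 3 rows per sample, 2 features per row.
A = [
    [[0, 1], [2, 3], [4, 5]],
    [[10, 11], [12, 13], [14, 15]],
]
B = [[2, 0], [1, 1]]
out = gather_batch_dims_1(A, B)
print(out)  # [[[4, 5], [0, 1]], [[12, 13], [12, 13]]]
```

The key point is that the indices of sample n only ever address rows of sample n, which is exactly what plain A[B] does not do in PyTorch.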

How do I achieve this in pytorch?

I have tried indexing with A[B], but this does not seem to work.

A[0][B[0]] works, but its output has shape [8, 2].

I need the shape of [minibatch, 8, 2]

It would probably work if I indexed each sample separately and stacked the results into a tensor of shape [minibatch, 8, 2], but I have no idea how to do that.
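The per-sample-then-stack idea can be sketched as follows, assuming B holds int64 indices into dimension 1 of A. The batch size of 4 and the random tensors are hypothetical stand-ins for the real data. This is correct but loops in Python, so it is mostly useful as a readable baseline.

```python
import torch

# Hypothetical sizes matching the question: A is [minibatch, 60, 2], B is [minibatch, 8].
minibatch = 4
A = torch.randn(minibatch, 60, 2)
B = torch.randint(0, 60, (minibatch, 8))  # int64 indices into dim 1 of A

# Index each sample separately: A[n][B[n]] has shape [8, 2] ...
per_sample = [A[n][B[n]] for n in range(A.shape[0])]
# ... then stack the results along a new leading batch dimension.
out = torch.stack(per_sample, dim=0)

print(out.shape)  # torch.Size([4, 8, 2])
```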

TensorFlow (here logits is A and indices is B):
    out = tf.gather(logits, indices, batch_dims=1)
PyTorch:
    out = A[B]  # something like this would be great

Desired output shape: [minibatch, 8, 2]
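Something close to the hoped-for A[B] is possible with PyTorch advanced indexing, assuming B contains int64 indices. A sketch (sizes and random data are hypothetical): pairing B with a batch index of shape [minibatch, 1] makes the two index tensors broadcast to [minibatch, 8], so entry (n, i) selects row B[n, i] of sample n, and the untouched last dimension of size 2 is kept.

```python
import torch

minibatch = 4
A = torch.randn(minibatch, 60, 2)           # [minibatch, 60, 2]
B = torch.randint(0, 60, (minibatch, 8))    # [minibatch, 8], int64

# batch_idx has shape [minibatch, 1] and broadcasts against B ([minibatch, 8]),
# so A[batch_idx, B][n, i] == A[n, B[n, i]].
batch_idx = torch.arange(A.shape[0]).unsqueeze(1)
out = A[batch_idx, B]                       # [minibatch, 8, 2]

print(out.shape)  # torch.Size([4, 8, 2])
```

Unlike the per-sample loop, this is a single vectorized indexing operation.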

Ayden Lee asked Jan 28 '26 22:01


1 Answer

I think you are looking for torch.gather:

out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))
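Why the expand is needed, as a hedged walkthrough with hypothetical sizes: torch.gather requires the index tensor to have the same number of dimensions as its input, and it computes out[n, i, j] = A[n, index[n, i, j], j] when gathering along dim 1. Expanding B from [minibatch, 8] to [minibatch, 8, 2] repeats each row index for every feature j, which reproduces tf.gather(A, B, batch_dims=1). The loop-based cross-check below is only there to confirm the result.

```python
import torch

minibatch = 4
A = torch.randn(minibatch, 60, 2)
B = torch.randint(0, 60, (minibatch, 8))  # gather requires int64 indices

# Add a trailing dim ([minibatch, 8, 1]) and expand it to [minibatch, 8, 2],
# so index[n, i, j] = B[n, i] for every feature j.
index = B[..., None].expand(*B.shape, A.shape[-1])
out = torch.gather(A, 1, index)  # out[n, i, j] = A[n, index[n, i, j], j]

# Cross-check against the explicit per-sample loop.
expected = torch.stack([A[n][B[n]] for n in range(minibatch)])
assert torch.equal(out, expected)
print(out.shape)  # torch.Size([4, 8, 2])
```

expand returns a view without copying data, so building the index this way stays cheap even for a large feature dimension.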
Shai answered Jan 30 '26 10:01



