 

PyTorch tensor indexing: how to gather rows using a tensor of indices

I have the tensors:

ids: shape (7000, 1), containing indices like [[1],[0],[2],...]

x: shape (7000, 3, 255)

For each row, the ids tensor encodes which index along dimension 1 of x (the dimension of size 3) should be selected. I want to gather the selected slices into a resulting tensor:

result: shape (7000, 255)
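As a plain loop, the desired operation picks the slice x[i, ids[i, 0], :] for each row i. The sketch below uses random placeholder data with the question's shapes. It is slow and is only meant to pin down the expected result that any vectorized version must reproduce.

```python
import torch

# Placeholder data with the shapes from the question
x = torch.randn(7000, 3, 255)
ids = torch.randint(0, 3, size=(7000, 1))

# Reference semantics: for each row i, take the 255-vector at position ids[i, 0] of dim 1
result = torch.stack([x[i, ids[i, 0].item()] for i in range(x.shape[0])])
print(result.shape)  # torch.Size([7000, 255])
```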

Background:

I have scores of shape (7000, 3), one for each of the 3 elements, and I want to select only the element with the highest score. Therefore, I used

ids = torch.argmax(scores, dim=1, keepdim=True)

which gives me the indices of the maximum scores. I already tried the gather function:

result = x.gather(1,ids)

but that didn't work.
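The call fails because torch.gather requires the index tensor to have the same number of dimensions as the input: ids is 2-D, with shape (7000, 1), while x is 3-D, with shape (7000, 3, 255). The following sketch reproduces the error with random placeholder data:

```python
import torch

x = torch.randn(7000, 3, 255)
scores = torch.randn(7000, 3)                        # placeholder scores
ids = torch.argmax(scores, dim=1, keepdim=True)      # shape (7000, 1), dtype int64

try:
    result = x.gather(1, ids)
    failed = False
except RuntimeError as err:
    # gather rejects an index whose number of dimensions differs from the input's
    failed = True
    print(err)
```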

Asked Oct 26 '25 08:10 by Hackster


1 Answer

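Here is a solution you may be looking for. torch.gather requires the index tensor to have the same number of dimensions as x. So first expand ids from (7000, 1) to (7000, 1, 255) by repeating each index across the last dimension, then gather along dim 1 and drop the singleton dimension:

ids = ids.repeat(1, 255).view(-1, 1, 255)
result = torch.gather(x, 1, ids).squeeze(1)  # shape (7000, 255)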
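The sketch below applies this idea to the question's shapes, with random placeholder data. It substitutes expand for repeat: expand returns a broadcast view without copying the indices, and gather accepts it as an index. The final squeeze(1) removes the singleton dimension, giving the (7000, 255) shape the question asks for.

```python
import torch

# Placeholder data with the question's shapes
x = torch.randn(7000, 3, 255)
scores = torch.randn(7000, 3)

ids = torch.argmax(scores, dim=1, keepdim=True)      # (7000, 1)

# Reshape to (7000, 1, 1), then broadcast to (7000, 1, 255) without copying
index = ids.unsqueeze(-1).expand(-1, -1, x.size(2))

# gather along dim 1: result[i, 0, k] = x[i, index[i, 0, k], k]
result = torch.gather(x, 1, index).squeeze(1)         # (7000, 255)
print(result.shape)  # torch.Size([7000, 255])
```

Both variants produce the same values. The repeat version materializes a (7000, 255) index tensor, while expand only changes strides.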

Here is a small example:

import torch

x = torch.arange(24).view(4, 3, 2)
"""
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
"""
ids = torch.tensor([[0], [2], [0], [2]])  # fixed for reproducibility; torch.randint(0, 3, size=(4, 1)) gives random indices
"""
tensor([[0],
        [2],
        [0],
        [2]])
"""
idx = ids.repeat(1, 2).view(4, 1, 2) 
"""
tensor([[[0, 0]],

        [[2, 2]],

        [[0, 0]],

        [[2, 2]]])
"""

torch.gather(x, 1, idx)  # shape (4, 1, 2); add .squeeze(1) to get shape (4, 2)
"""
tensor([[[ 0,  1]],

        [[10, 11]],

        [[12, 13]],

        [[22, 23]]])
"""
Answered Oct 28 '25 20:10 by David Ng


