I have the tensors:
ids: shape (7000,1) containing indices like [[1],[0],[2],...]
x: shape (7000,3,255)
The ids tensor encodes, for each row, the index along the second dimension of x (the one of size 3) that should be selected.
I want to gather the selected slices into a resulting tensor:
result: shape (7000,255)
Background:
I have scores (shape = (7000,3)) for each of the 3 elements and want to select only the element with the highest score. Therefore, I used the function
ids = torch.argmax(scores,1,True)
which gives me the indices of the maximum scores. I already tried to do it with the gather function:
result = x.gather(1,ids)
but that didn't work.
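Here is a solution you may be looking for. torch.gather expects an index tensor with the same number of dimensions as x, so first expand ids to shape (7000,1,255) and then gather along dimension 1:
ids = ids.repeat(1, 255).view(-1, 1, 255)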
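To check this at the shapes from the question, here is a minimal sketch. The random `scores` and `x` are stand-ins for the real data, and the `failed` flag and `idx` name are only for illustration. It assumes a current PyTorch release, where passing a 2-D index to gather on a 3-D input raises a RuntimeError.

```python
import torch

# Shapes from the question; the random values are placeholders.
scores = torch.randn(7000, 3)
x = torch.randn(7000, 3, 255)
ids = torch.argmax(scores, 1, True)           # shape (7000, 1)

# gather needs index and input to have the same number of dimensions,
# so the original call with a 2-D index on a 3-D input is rejected.
try:
    x.gather(1, ids)
    failed = False
except RuntimeError:
    failed = True

# Repeat each id across the last dimension, then add a singleton dim 1.
idx = ids.repeat(1, 255).view(-1, 1, 255)     # shape (7000, 1, 255)

# Gather along dim 1 and drop the singleton dimension.
result = x.gather(1, idx).squeeze(1)          # shape (7000, 255)
```

The final squeeze(1) is what turns the gathered (7000, 1, 255) tensor into the (7000, 255) result asked for.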
A small example:
import torch

x = torch.arange(24).view(4, 3, 2)
"""
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
"""
# Random ids stand in for the argmax result; one possible draw is shown.
ids = torch.randint(0, 3, size=(4, 1))
"""
tensor([[0],
        [2],
        [0],
        [2]])
"""
idx = ids.repeat(1, 2).view(4, 1, 2)
"""
tensor([[[0, 0]],

        [[2, 2]],

        [[0, 0]],

        [[2, 2]]])
"""
# The result has shape (4, 1, 2); append .squeeze(1) to get shape (4, 2).
torch.gather(x, 1, idx)
"""
tensor([[[ 0,  1]],

        [[10, 11]],

        [[12, 13]],

        [[22, 23]]])
"""
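Two variations on the same idea may be worth knowing. This is a sketch and not part of the original answer. It uses fixed ids matching the draw shown above, and the names `via_gather`, `via_index`, and `rows` are made up for illustration. Option A builds the index with expand, which creates a broadcast view instead of copying the ids the way repeat does. Option B skips gather and uses integer-array indexing.

```python
import torch

x = torch.arange(24).view(4, 3, 2)
ids = torch.tensor([[0], [2], [0], [2]])            # fixed ids, same as the draw above

# Option A: expand the ids to shape (4, 1, 2) as a view, then gather.
idx = ids.unsqueeze(-1).expand(-1, 1, x.size(2))    # shape (4, 1, 2)
via_gather = torch.gather(x, 1, idx).squeeze(1)     # shape (4, 2)

# Option B: pair row i with ids[i] via integer-array indexing.
rows = torch.arange(x.size(0))                      # tensor([0, 1, 2, 3])
via_index = x[rows, ids.squeeze(1)]                 # shape (4, 2)

print(via_gather)
# tensor([[ 0,  1],
#         [10, 11],
#         [12, 13],
#         [22, 23]])
```

Both options select the same rows here; which one reads better is largely a matter of taste.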