I have the following code:
a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])
I have a multi-dimensional index b and want to use it to select a single cell in a. If b wasn't a tensor, I could do:
a[1,1,1,1]
which returns the correct cell, but:
a[b]
doesn't work, because it just selects a[1] four times.
How can I do this? Thanks
You can split b into 4 using chunk, and then use the chunked b to index the specific element you want:
>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)] # here's the trick!
Out[24]: tensor([[40, 80, 0]])
What's nice about this is that it generalizes easily to a of any number of dimensions: you just need to make the number of chunks equal to the number of dimensions of a.
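As a rough sketch of that generalization, the same trick can be applied to a 3-D tensor by taking the chunk count from a.dim(). The shapes and index values below are made up for illustration, and the explicit tuple(...) wrapper is just a precaution so the index is a tuple regardless of what chunk returns:

```python
import torch

# Illustrative 3-D tensor; any number of dimensions should work the same way.
a = torch.arange(2 * 3 * 4).view(2, 3, 4)

# Each column of b is one full index into a (after the transpose, b has
# one row per dimension of a). Here the indices are (1, 2, 3) and (0, 1, 2).
b = torch.LongTensor([[1, 2, 3], [0, 1, 2]]).t()

# Split b into one index tensor per dimension of a, then index with the tuple.
index = tuple(b.chunk(chunks=a.dim(), dim=0))
picked = a[index]

print(picked)  # holds a[1, 2, 3] and a[0, 1, 2]
```

Because the number of chunks comes from a.dim(), the same line should keep working if a gains or loses a dimension, as long as b has one row per dimension.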
A more elegant (and simpler) solution might be to simply cast b as a tuple:
a[tuple(b)]
Out[10]: tensor(5.)
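To make the difference concrete, here is a small sketch comparing a[b] with a[tuple(b)]. It uses torch.arange instead of the random tensor from the question so that the selected value is predictable; that substitution is only for illustration:

```python
import torch

a = torch.arange(3 * 3 * 3 * 3).view(3, 3, 3, 3)
b = torch.LongTensor([1, 1, 1, 1])

# A tensor index is applied to the first dimension only,
# so this stacks a[1] four times.
rows = a[b]

# A tuple supplies one index per dimension, like a[1, 1, 1, 1].
cell = a[tuple(b)]

print(rows.shape)  # four copies of a[1]
print(cell)        # a single element
```

If plain Python integers are preferred over 0-dim tensors in the tuple, tuple(b.tolist()) should behave the same way.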
I was curious to see how this works with "regular" numpy, and found a related article explaining this quite well here.