
Indexing a multi-dimensional tensor with a tensor in PyTorch

Tags: tensor, pytorch

I have the following code:

import torch

a = torch.randint(0, 10, [3, 3, 3, 3])
b = torch.LongTensor([1, 1, 1, 1])

I have a multi-dimensional index b and want to use it to select a single cell in a. If b wasn't a tensor, I could do:

a[1,1,1,1]

Which returns the correct cell, but:

a[b]

Doesn't work, because b is treated as a list of indices into the first dimension only, so it just selects a[1] four times.

How can I do this? Thanks

Chum-Chum Scarecrows asked Aug 30 '18 08:08


2 Answers

You can split b into 4 using chunk, and then use the chunked b to index the specific element you want:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

What's nice about this is that it generalizes easily to a tensor a with any number of dimensions: you just need to make the number of chunks equal to the number of dimensions of a.
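The generalization could be sketched roughly as follows. The helper name gather_cells and its one-index-per-row argument layout are hypothetical, chosen only for illustration; the chunk call itself is the one used in the answer above, with the chunk count taken from a.dim().

```python
import torch

def gather_cells(a: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick one cell of `a` per row of `idx`.

    `idx` is assumed to have shape [n, a.dim()], one full index per row.
    """
    # Transpose so that each row holds the indices for one dimension of `a`,
    # then split into a.dim() chunks of shape [1, n], one per dimension.
    chunks = tuple(idx.t().chunk(chunks=a.dim(), dim=0))
    # Indexing with a tuple of tensors supplies one index tensor per
    # dimension; the result has shape [1, n], so drop the leading axis.
    return a[chunks].squeeze(0)

a = torch.arange(3 * 3 * 3 * 3).view(3, 3, 3, 3)
idx = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [0, 0, 0, 0]])
result = gather_cells(a, idx)
print(result)  # tensor([40, 80,  0])
```

Because the chunk count comes from a.dim(), the same call should work unchanged for a 3-D or 5-D tensor, as long as each row of idx has one entry per dimension.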

Shai answered Oct 25 '22 15:10


A more elegant (and simpler) solution might be to simply convert b to a tuple:

a[tuple(b)]
Out[10]: tensor(5.)

I was curious to see how this works with "regular" numpy, and found a related article explaining this quite well here.
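As a rough illustration of the difference (a sketch only: it uses torch.arange instead of random values so the results are predictable, and the batched variant with unbind is an addition not taken from the answers), a single index tensor indexes just the first dimension, while a tuple supplies one index per dimension, which is the same rule NumPy follows:

```python
import torch

a = torch.arange(3 * 3 * 3 * 3).view(3, 3, 3, 3)
b = torch.tensor([1, 1, 1, 1])

# A single index tensor applies to the first dimension only,
# so a[b] stacks four copies of a[1].
stacked = a[b]
print(stacked.shape)  # torch.Size([4, 3, 3, 3])

# A tuple provides one index per dimension, like a[1, 1, 1, 1].
cell = a[tuple(b)]
print(cell)  # tensor(40)

# Assumed extension: for several indices at once (one per row),
# unbind along dim=1 to get one index tensor per dimension of `a`.
many = torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2], [0, 0, 0, 0]])
cells = a[many.unbind(dim=1)]
print(cells)  # tensor([40, 80,  0])
```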

dennlinger answered Oct 25 '22 15:10