PyTorch: select values from the last tensor dimension with indices from another tensor with a smaller dimension

Tags:

pytorch

I have a tensor a with three dimensions. The first dimension corresponds to minibatch size, the second to the sequence length, and the third to the feature dimension. E.g.,

>>> a = torch.arange(1, 13, dtype=torch.float).view(2,2,3)  # Consider the values of a to be random
>>> a
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])

I have a second, two-dimensional tensor b. Its first dimension corresponds to the minibatch size and its second dimension to the sequence length. It contains values that index into the third dimension of a. Since the third dimension of a has size 3, b can contain the values 0, 1, or 2. E.g.,

>>> b = torch.LongTensor([[0, 2],[1,0]])
>>> b
tensor([[0, 2],
        [1, 0]])

I want to obtain a tensor c that has the shape of b and contains all the values of a that are referenced by b. In the scenario above, I would like to have:

c = torch.empty(2,2)
c[0,0] = a[0, 0, b[0,0]]
c[1,0] = a[1, 0, b[1,0]]
c[0,1] = a[0, 1, b[0,1]]
c[1,1] = a[1, 1, b[1,1]]

>>> c
tensor([[ 1.,  6.],
        [ 8., 10.]])

How can I create the tensor c efficiently? Furthermore, I also want c to be differentiable (i.e., usable with .backward()). I am not too familiar with PyTorch, so I am not sure whether a differentiable version of this exists.

As an alternative, instead of c having the same shape as b, I could also use a c with the same shape as a, containing only zeros, except for ones at the positions referenced by b. Then I could multiply a and c to obtain a differentiable tensor.

Like follows:

c = torch.zeros(2,2,3, dtype=torch.float)
c[0,0,b[0,0]] = 1
c[1,0,b[1,0]] = 1
c[0,1,b[0,1]] = 1
c[1,1,b[1,1]] = 1

>>> a*c
tensor([[[ 1.,  0.,  0.],
         [ 0.,  0.,  6.]],

        [[ 0.,  8.,  0.],
         [10.,  0.,  0.]]])
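As a side note on this alternative: the mask can likely be built without per-element assignments. A minimal sketch using torch.nn.functional.one_hot (assuming the same a and b as above; summing the masked product over the last dimension then recovers a c with the shape of b):

```python
import torch
import torch.nn.functional as F

a = torch.arange(1, 13, dtype=torch.float).view(2, 2, 3)
b = torch.LongTensor([[0, 2], [1, 0]])

# one_hot turns each index in b into a length-3 indicator vector,
# giving a 0/1 mask with the same shape as a
mask = F.one_hot(b, num_classes=3).to(a.dtype)

masked = a * mask            # zeros everywhere except the referenced positions
c = masked.sum(dim=-1)       # collapse the last dimension to get b's shape
```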
Asked Nov 07 '22 by Skyy2010

1 Answer

Let's declare the necessary variables first (notice requires_grad in a's initialization; we will use it to ensure differentiability):

a = torch.arange(1,13,dtype=torch.float32,requires_grad=True).reshape(2,2,3)
b = torch.LongTensor([[0, 2],[1,0]])

Let's reshape a and squash the minibatch and sequence dimensions:

temp = a.reshape(-1,3)

so temp now looks like:

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]], grad_fn=<AsStridedBackward>)

Notice that each value of b can now be used to index a row of temp to get the desired output. Now we do:

c = temp[range(len(temp)), b.view(-1)].view(b.size())

Notice how we index temp: range(len(temp)) selects each row, and the flattened b, i.e. b.view(-1), picks the corresponding column in each row. Lastly, .view(b.size()) reshapes the result to the same shape as b.

If we print c now:

tensor([[ 1.,  6.],
        [ 8., 10.]], grad_fn=<ViewBackward>)

The presence of grad_fn=... shows that c tracks gradients, i.e., it is differentiable.
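For completeness (this goes beyond the original answer), the same selection can also be expressed in a single call with torch.gather, which indexes along one dimension and is likewise differentiable. A sketch, assuming the same a and b:

```python
import torch

# requires_grad_() makes a a leaf tensor so gradients accumulate in a.grad
a = torch.arange(1, 13, dtype=torch.float32).reshape(2, 2, 3).requires_grad_()
b = torch.LongTensor([[0, 2], [1, 0]])

# gather picks, along dim=2, the element whose index is given by b;
# b needs a trailing dimension of size 1 to match a's number of dimensions
c = torch.gather(a, 2, b.unsqueeze(-1)).squeeze(-1)

c.sum().backward()  # gradients flow back to a: 1 at each selected position
```

This avoids the intermediate reshape and keeps the batch and sequence dimensions explicit.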

Answered Dec 01 '22 by Shihab Shahriar Khan