 

Indexing a 3d tensor using a 2d tensor

Tags: pytorch

I have a 3d tensor, source, of shape (bsz x slen1 x nhid) and a 2d tensor, index, of shape (bsz x slen2). More specifically, I have:

source = 32 x 20 x 768
index  = 32 x 16

Each value in the index tensor lies in [0, 19] and is the position of the desired vector along the 2nd dimension (slen1) of the source tensor.

After indexing, I expect an output tensor of shape 32 x 16 x 768.

Currently I am doing this:

bsz, _, nhid = source.size()
_, slen = index.size()

source = source.reshape(-1, nhid)
source = source[index.reshape(-1), :]
source = source.reshape(bsz, slen, nhid)

So, I am flattening the 3d source tensor to a 2d tensor and the 2d index tensor to a 1d tensor, and then performing the indexing. Is this correct?

Is there a better way to do it?

Update

I checked, and my code does not give the expected result. To explain what I want, here is a small example.

import torch

source = torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820],
     [ 0.3490, -0.0198,  0.7928]],

    [[-0.0973,  2.3106, -1.8358],
     [-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

index = torch.LongTensor([[0, 1, 2, 3], 
                          [1, 2, 3, 4]])

And I want the output tensor as:

torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820]],

    [[-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

Asked by Wasi Ahmad on Apr 11 '19 at 08:04



1 Answer

Update:

source[torch.arange(source.shape[0]).unsqueeze(-1), index]

Note that torch.arange(source.shape[0]).unsqueeze(-1) gives:

tensor([[0],
        [1]])  # 2 x 1

and index is:

tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]])  # 2 x 4

The arange tensor indexes the batch dimension while index simultaneously indexes the slen1 dimension. The unsqueeze call adds the trailing dimension of size 1 to the arange result so that the two index tensors broadcast to a common shape (bsz x slen2). The result therefore has shape bsz x slen2 x nhid.
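For reference, here is a minimal sketch that runs the indexing on the example data from the question and cross-checks it two other ways. The names batch_ids, out_index, out_gather and out_flat are illustrative and not from the original post. The torch.gather variant and the offset-corrected flattening are alternatives I am adding for comparison, not part of the answer above.

```python
import torch

# Example data from the question: source is (bsz=2, slen1=5, nhid=3).
source = torch.tensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820],
     [ 0.3490, -0.0198,  0.7928]],
    [[-0.0973,  2.3106, -1.8358],
     [-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]],
])
index = torch.tensor([[0, 1, 2, 3],
                      [1, 2, 3, 4]])  # (bsz=2, slen2=4), integer (int64) indices

bsz, slen1, nhid = source.shape

# 1) Broadcast advanced indexing: (bsz, 1) batch ids against (bsz, slen2) positions.
batch_ids = torch.arange(bsz).unsqueeze(-1)
out_index = source[batch_ids, index]  # (bsz, slen2, nhid)

# 2) torch.gather along dim 1, with the index expanded over the hidden dimension.
out_gather = source.gather(1, index.unsqueeze(-1).expand(-1, -1, nhid))

# 3) The flattened approach from the question, plus a per-batch row offset.
#    Without the offset, every index would select rows of batch element 0 only.
flat_index = (index + batch_ids * slen1).reshape(-1)
out_flat = source.reshape(-1, nhid)[flat_index].reshape(bsz, -1, nhid)

# Expected result for this data: rows 0..3 of batch 0 and rows 1..4 of batch 1.
expected = torch.stack([source[0, 0:4], source[1, 1:5]])

print(out_index.shape)
print(torch.equal(out_index, expected))
print(torch.equal(out_gather, expected))
print(torch.equal(out_flat, expected))
```

The third variant also suggests why the code in the question goes wrong: after flattening source to (bsz * slen1) rows, each row of index has to be shifted by its batch number times slen1 before it points into the right batch element.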

Answered by colesbury on Sep 29 '22 at 07:09