Indexing the max elements in a multidimensional tensor in PyTorch

I'm trying to index the maximum elements along the last dimension in a multidimensional tensor. For example, say I have a tensor

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

Here idx stores the indices of the maxima, which may look something like

>>> A
tensor([[[ 1.0503,  0.4448,  1.8663],
         [ 0.8627,  0.0685,  1.4241]],

        [[ 1.2924,  0.2456,  0.1764],
         [ 1.3777,  0.9401,  1.4637]],

        [[ 0.5235,  0.4550,  0.2476],
         [ 0.7823,  0.3004,  0.7792]],

        [[ 1.9384,  0.3291,  0.7914],
         [ 0.5211,  0.1320,  0.6330]],

        [[ 0.3292,  0.9086,  0.0078],
         [ 1.3612,  0.0610,  0.4023]]])
>>> idx
tensor([[ 2,  2],
        [ 0,  2],
        [ 0,  0],
        [ 0,  2],
        [ 1,  0]])
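With `dim` given, `torch.max` actually returns a `(values, indices)` pair; the snippet above discards the values with `_`. As a minimal sketch on fresh random data (so the numbers will differ from the printout), the discarded values are exactly the entries that `idx` points at, which `torch.gather` can confirm:

```python
import torch

A = torch.randn(5, 2, 3)

# With dim given, torch.max returns (values, indices), each of shape (5, 2)
values, idx = torch.max(A, dim=2)

# gather needs an index with as many dims as A, hence unsqueeze(-1) -> (5, 2, 1)
picked = A.gather(2, idx.unsqueeze(-1)).squeeze(-1)  # back to shape (5, 2)

print(idx.shape)                    # torch.Size([5, 2])
print(torch.equal(picked, values))  # True
```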

I want to use these indices to assign values to another tensor. That is, I want to be able to do something like

B = A.new_zeros(A.size())
B[idx] = A[idx]

where B is 0 everywhere except where A is maximal along the last dimension. That is, B should contain

>>> B
tensor([[[ 0,  0,  1.8663],
         [ 0,  0,  1.4241]],

        [[ 1.2924,  0,  0],
         [ 0,  0,  1.4637]],

        [[ 0.5235,  0,  0],
         [ 0.7823,  0,  0]],

        [[ 1.9384,  0,  0],
         [ 0,  0,  0.6330]],

        [[ 0,  0.9086,  0],
         [ 1.3612,  0,  0]]])

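This is proving to be much more difficult than I expected, because idx does not index A the way I need: indexing with idx alone only selects along the first dimension, so A[idx] has shape (5, 2, 2, 3) instead of picking out one element per row. So far I have been unable to find a vectorized way to use idx to index A.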

Is there a good vectorized way to do this?
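A quick shape check makes the failure concrete (sketch on random data): a single integer index tensor is advanced indexing on dimension 0 only, so every entry `k` of `idx` selects the whole `(2, 3)` slice `A[k]` rather than one element.

```python
import torch

A = torch.randn(5, 2, 3)
_, idx = torch.max(A, dim=2)  # idx has shape (5, 2), values in {0, 1, 2}

# One index tensor only indexes dim 0: result shape is idx.shape + A.shape[1:]
print(A[idx].shape)  # torch.Size([5, 2, 2, 3]), not one element per (i, j)
```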

asked Jan 05 '19 23:01 by user395788


1 Answer

You can use torch.meshgrid to create an index tuple:

>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]
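A self-contained version of the meshgrid approach, using the tensor printed in the question so the result can be compared with the expected B. It assumes PyTorch 1.10 or newer for the `indexing="ij"` keyword (older versions always use "ij" and do not accept the keyword), and wraps the meshgrid result in `tuple(...)` before appending `idx`.

```python
import torch

# The example tensor from the question
A = torch.tensor([[[1.0503, 0.4448, 1.8663], [0.8627, 0.0685, 1.4241]],
                  [[1.2924, 0.2456, 0.1764], [1.3777, 0.9401, 1.4637]],
                  [[0.5235, 0.4550, 0.2476], [0.7823, 0.3004, 0.7792]],
                  [[1.9384, 0.3291, 0.7914], [0.5211, 0.1320, 0.6330]],
                  [[0.3292, 0.9086, 0.0078], [1.3612, 0.0610, 0.4023]]])

_, idx = torch.max(A, dim=2)

# One index tensor per leading dimension, each of shape (5, 2)
grid = torch.meshgrid(*[torch.arange(n) for n in A.shape[:-1]], indexing="ij")
index_tuple = tuple(grid) + (idx,)

# Copy only the maxima; everything else stays zero
B = torch.zeros_like(A)
B[index_tuple] = A[index_tuple]
print(B)
```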

Note that for the specific case of a 3D tensor, you can also mimic meshgrid with broadcasting:

>>> index_tuple = (
...     torch.arange(A.size(0))[:, None],
...     torch.arange(A.size(1))[None, :],
...     idx
... )
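The broadcasting trick extends to any number of leading dimensions by reshaping each `arange` so that it varies along a single axis; for a 3D tensor this produces exactly the `(5, 1)` and `(1, 2)` shapes above. A sketch, where `max_index_tuple` is a hypothetical helper name, not a PyTorch function:

```python
import torch

def max_index_tuple(A, idx):
    """Hypothetical helper: build an index tuple that picks A[..., idx] element-wise.

    idx must have shape A.shape[:-1]; the aranges broadcast against it.
    """
    lead = A.dim() - 1
    index = []
    for d in range(lead):
        shape = [1] * lead
        shape[d] = A.shape[d]  # this arange varies only along axis d
        index.append(torch.arange(A.shape[d]).view(shape))
    return tuple(index) + (idx,)

# Usage on a 4D tensor
A = torch.randn(4, 3, 2, 5)
_, idx = torch.max(A, dim=-1)
B = torch.zeros_like(A)
ix = max_index_tuple(A, idx)
B[ix] = A[ix]
```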

A bit more explanation:
Suppose the indices look something like this (they come from a different random A than the one in the question):

In [173]: idx 
Out[173]: 
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])

From this, we need to get to three indices per element, because A is 3D and retrieving a single element takes three numbers. The first two come from a grid over the first two dimensions (which is why we use meshgrid), as shown below.

In [174]: A[0, 0, 2], A[0, 1, 1]  
Out[174]: (tensor(0.6288), tensor(-0.3070))

In [175]: A[1, 0, 2], A[1, 1, 0]  
Out[175]: (tensor(1.7085), tensor(0.7818))

In [176]: A[2, 0, 2], A[2, 1, 1]  
Out[176]: (tensor(0.4823), tensor(1.1199))

In [177]: A[3, 0, 2], A[3, 1, 2]    
Out[177]: (tensor(1.6903), tensor(1.0800))

In [178]: A[4, 0, 2], A[4, 1, 2]          
Out[178]: (tensor(0.9138), tensor(0.1779))

In the five lines above, the first two numbers of each index are the grid we build with meshgrid, and the third number comes from idx. That is, the first two numbers form this grid:

 (0, 0) (0, 1)
 (1, 0) (1, 1)
 (2, 0) (2, 1)
 (3, 0) (3, 1)
 (4, 0) (4, 1)
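To see this grid produced directly, stack the two tensors that `meshgrid` returns for the first two dimensions of a `(5, 2, 3)` tensor; each position then holds the `(row, column)` pair listed above. This sketch assumes PyTorch 1.10 or newer for `indexing="ij"`.

```python
import torch

rows, cols = torch.meshgrid(torch.arange(5), torch.arange(2), indexing="ij")
pairs = torch.stack((rows, cols), dim=-1)  # shape (5, 2, 2): pairs[i, j] == [i, j]
print(pairs.tolist())
# [[[0, 0], [0, 1]], [[1, 0], [1, 1]], [[2, 0], [2, 1]], [[3, 0], [3, 1]], [[4, 0], [4, 1]]]
```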
answered Nov 14 '22 21:11 by a_guest