
How to efficiently retrieve the indices of maximum values in a Torch tensor?

Assume to have a torch tensor, for example of the following shape:

x = torch.rand(20, 1, 120, 120)

What I would like now is to get the indices of the maximum value of each 120x120 matrix. To simplify the problem, I would first call x.squeeze() to work with the shape [20, 120, 120]. I would then like to get a torch tensor which is a list of indices with shape [20, 2].

How can I do this fast?

asked Nov 08 '18 by Chris



2 Answers

torch.topk() is what you are looking for. From the docs,

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

  • If dim is not given, the last dimension of the input is chosen.

  • If largest is False then the k smallest elements are returned.

  • A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

  • If the boolean option sorted is True, the returned k elements will themselves be sorted.
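Applied to the question's tensor, a minimal sketch could look like this (flattening each matrix and converting the flat index back to row/column coordinates is my own addition, not part of the docs quoted above):

```python
import torch

x = torch.rand(20, 1, 120, 120)
flat = x.squeeze(1).view(20, -1)            # [20, 14400], one row per matrix
values, idx = torch.topk(flat, k=1, dim=1)  # flat index of each row's maximum
# recover (row, col) from the flattened index -> shape [20, 2]
indices = torch.cat((idx // 120, idx % 120), dim=1)
```

With k=1 this reduces to argmax, but topk also generalizes directly if you later want the top-k positions per matrix.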

answered Sep 20 '22 by tejasvi88


If I understand you correctly, you don't want the values, but the indices. Unfortunately there is no out-of-the-box solution. There is an argmax() function, but I cannot see how to get it to do exactly what you want.

So here is a small workaround; the efficiency should also be okay since we're just reshaping and dividing tensors:

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# argmax() only returns the index into the flattened matrix
# block, so we have to recover the row and column indices
# ourselves using floor division // and modulo %
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)

n represents your first dimension, and d the last two dimensions. I take smaller numbers here to show the result. But of course this will also work for n=20 and d=120:

n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)

Here is the output for n=4 and d=4:

tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],


        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],


        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],


        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])

I hope this is what you wanted to get! :)

Edit:

Here is a slightly modified version which might be marginally faster (probably not by much), but it is a bit simpler and prettier:

Instead of this, as before:

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

The necessary reshaping is already done on the argmax values:

m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)

But as mentioned in the comments, I don't think it is possible to get much more out of it.

One thing you could do, if it is really important for you to squeeze the last possible bit of performance out of it, is to implement the function above as a low-level extension (e.g. in C++) for PyTorch.

This would give you a single function you can call and would avoid slow Python code.

https://pytorch.org/tutorials/advanced/cpp_extension.html

answered Sep 19 '22 by MBT