
Extracting the top-k value-indices from a 1-D Tensor

Given a 1-D tensor in Torch (torch.Tensor) containing comparable values (say, floating point), how can we extract the indices of the top-k values in that tensor?

Apart from the brute-force method, I am looking for an API call provided by Torch/Lua that performs this task efficiently.

dyno8426 asked Jan 12 '16 17:01



2 Answers

As of pull request #496, Torch includes a built-in API named torch.topk. Example:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

-- you can also get the indices of those elements
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

-- alternatively, you can obtain the k largest elements as follows
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]

At the time of writing, the CPU implementation follows a sort-and-narrow approach (there are plans to improve it in the future). An optimized GPU implementation for cutorch is currently under review.
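To make "sort and narrow" concrete, here is a minimal sketch in plain Python (not Torch's actual source): sort every index by value, then keep only the first k. The function names `topk_sort_and_narrow` and `topk_heap` are hypothetical, and the heap-based variant is only one possible cheaper alternative, not necessarily what Torch planned to adopt. Indices here are 0-based, unlike the 1-based Lua output above.

```python
import heapq


def topk_sort_and_narrow(values, k, largest=False):
    """Sort all indices by value, then keep the first k (the 'narrow' step)."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    idx = order[:k]
    return [values[i] for i in idx], idx


def topk_heap(values, k, largest=False):
    """Partial selection with a heap, which avoids sorting the whole input."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    idx = pick(k, range(len(values)), key=lambda i: values[i])
    return [values[i] for i in idx], idx


t = [9, 1, 8, 2, 7, 3, 6, 4, 5]
print(topk_sort_and_narrow(t, 3))      # ([1, 2, 3], [1, 3, 5])
print(topk_heap(t, 3, largest=True))   # ([9, 8, 7], [0, 2, 4])
```

The full sort costs O(n log n) regardless of k, which is why a partial-selection strategy can be attractive when k is much smaller than n.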

deltheil answered Sep 22 '22 04:09


In PyTorch, you can use the topk function.

For example:

import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

# topk returns the k largest values and their 0-based indices
values, indices = t.topk(2)

print(values)
print(indices)

The result:

tensor([9.5000, 6.1000])
tensor([2, 4])
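
If the smallest values are wanted instead, the same call accepts a `largest` keyword argument. The sketch below assumes a reasonably recent PyTorch release; the variable names are illustrative only, and the cross-check against `torch.argsort` is there just to show that both routes pick the same positions.

```python
import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

# largest=False selects the k smallest values instead of the k largest
small_values, small_indices = t.topk(2, largest=False)

# Cross-check: the first k entries of an ascending argsort
# should cover the same positions as topk with largest=False
order = torch.argsort(t)
assert set(small_indices.tolist()) == set(order[:2].tolist())

print(small_values)
print(small_indices)
```

Note that PyTorch defaults to the k largest elements, whereas the Lua Torch call shown in the other answer defaults to the k smallest.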

ChaosPredictor answered Sep 23 '22 04:09