Given a 1-D tensor in Torch (torch.Tensor) containing values that can be compared (say, floating point), how can we extract the indices of the top-k values in that tensor?
Apart from the brute-force method, I am looking for an API call provided by Torch/Lua that can perform this task efficiently.
To find the index of the maximum element in an input tensor, we can apply the torch.argmax() function. It returns only the index, not the element value. If the input tensor has multiple maximal values, the function returns the index of the first maximal element.
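A small PyTorch sketch of the tie behaviour described above. The tensor values are made up for illustration. Note that argmax only ever yields a single index, which is why the top-k question needs a different call.

```python
import torch

# 7.0 is the maximum and appears twice: at index 1 and at index 3
t = torch.tensor([3.0, 7.0, 1.0, 7.0])

# argmax returns a 0-dim integer tensor holding one index
idx = torch.argmax(t)
print(idx.item())  # index of the first maximal element

# argmax gives one index only; topk returns the k best values at once
vals, _ = torch.topk(t, k=2)
print(vals)
```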
torch.topk returns the k largest elements of the given input tensor along a given dimension. If dim is not given, the last dimension of the input is chosen.
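The following sketch shows how the dim argument changes what topk compares. It uses PyTorch's `torch.topk(input, k, dim=None, largest=True, sorted=True)` signature on a made-up 2-D tensor. It also uses `largest=False` to get the smallest elements instead.

```python
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# dim omitted: the last dimension is used, i.e. top-2 within each row
vals, idx = torch.topk(x, k=2)

# dim=0: top-1 within each column
vals0, idx0 = torch.topk(x, k=1, dim=0)

# largest=False selects the k smallest elements instead
small_vals, small_idx = torch.topk(x, k=1, largest=False)

print(vals, idx)
print(vals0, idx0)
print(small_vals, small_idx)
```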
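In PyTorch, topk is differentiable with respect to the returned values. Gradients flow back to the selected input elements, while the returned indices carry no gradient. This means it can be used inside a differentiable pipeline.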
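The sketch below checks the gradient behaviour on a tiny made-up example. Summing the top-2 values and backpropagating should put a gradient of 1 on the two selected inputs and 0 everywhere else. The integer index tensor does not take part in autograd.

```python
import torch

x = torch.tensor([0.5, 2.0, -1.0, 3.0], requires_grad=True)

vals, idx = torch.topk(x, k=2)   # picks 3.0 (index 3) and 2.0 (index 1)
loss = vals.sum()
loss.backward()

print(x.grad)             # 1 at the selected positions, 0 elsewhere
print(idx.requires_grad)  # integer indices are not tracked by autograd
```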
To create a tensor from pre-existing data, use torch.tensor(). To create a tensor with a specific size, use the torch.* tensor creation ops (see Creation Ops). To create a tensor with the same size (and similar types) as another tensor, use the torch.*_like tensor creation ops (see Creation Ops).
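Here is one example from each of the three families of creation ops mentioned above. The particular ops chosen (`zeros`, `empty`, `ones_like`) are just representative picks.

```python
import torch

a = torch.tensor([5.7, 1.4, 9.5])            # from pre-existing data (float32 by default)
b = torch.zeros(2, 3)                        # a torch.* op with a specific size
d = torch.empty(4, dtype=torch.float64)      # uninitialized, with an explicit dtype
c = torch.ones_like(a)                       # a torch.*_like op: same size/dtype as a

print(a.dtype, b.shape, c, d.shape)
```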
As of pull request #496, Torch now includes a built-in API named torch.topk. Example:
> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}
-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
1
2
3
[torch.DoubleTensor of size 3]
-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
2
4
6
[torch.LongTensor of size 3]
-- alternatively, you can obtain the k largest elements as follows
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
9
8
7
[torch.DoubleTensor of size 3]
At the time of writing, the CPU implementation follows a sort-and-narrow approach (there are plans to improve it in the future). That said, an optimized GPU implementation for cutorch is currently under review.
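To make the sort-and-narrow idea concrete, here is a rough sketch written in PyTorch (Python) rather than Lua. It is not the Torch7 source. `topk_sort_narrow` is a hypothetical helper name. It fully sorts the tensor in O(n log n) time and then keeps a view of the first k entries, using the same data as the Lua example above.

```python
import torch

def topk_sort_narrow(t: torch.Tensor, k: int, largest: bool = True):
    """Hypothetical helper: top-k of a 1-D tensor via a full sort, then narrow."""
    # Sort everything; descending order puts the largest values first
    sorted_vals, sorted_idx = torch.sort(t, descending=largest)
    # narrow(dim, start, length) returns a view of the first k entries
    return sorted_vals.narrow(0, 0, k), sorted_idx.narrow(0, 0, k)

t = torch.tensor([9.0, 1.0, 8.0, 2.0, 7.0, 3.0, 6.0, 4.0, 5.0])

small_vals, small_idx = topk_sort_narrow(t, 3, largest=False)
big_vals, big_idx = topk_sort_narrow(t, 3, largest=True)

print(small_vals, small_idx)  # indices are 0-based here, unlike Lua's 1-based 2, 4, 6
print(big_vals, big_idx)
```

A dedicated selection routine can avoid sorting the whole tensor when k is much smaller than n. That is the kind of improvement the note about future CPU work points at.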
In PyTorch, you can use the topk function.
For example:
import torch
t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])
values, indices = t.topk(2)
print(values)
print(indices)
The result:
tensor([9.5000, 6.1000])
tensor([2, 4])
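In PyTorch, topk returns a named tuple, so its fields can also be read by name. The indices can then be used to index straight back into the original tensor. The sketch below reuses the same data and also shows `largest=False` for the smallest values.

```python
import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

res = t.topk(2)
print(res.values)   # tensor([9.5000, 6.1000])
print(res.indices)  # tensor([2, 4])

# the indices select the same elements from the original tensor
assert torch.equal(t[res.indices], res.values)

# the 2 smallest elements instead
low = t.topk(2, largest=False)
print(low.indices)  # tensor([1, 3])
```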