 

How to get the index of a specific value in a PyTorch tensor

Tags:

python

pytorch

With a Python list, we can use list.index(somevalue) to find the position of a value. How can PyTorch do this?
For example:

    a = [1, 2, 3]
    print(a.index(2))

This prints 1. How can a PyTorch tensor do the same without being converted to a Python list?

asked Dec 18 '17 at 06:12 by Han Bing


People also ask

How do you find the tensor index?

To find the index of the maximum element of an input tensor, apply the torch.argmax() function. It returns only the index, not the element's value. If the maximum value occurs more than once, the function returns the index of the first occurrence.
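A minimal sketch of `torch.argmax()`. It assumes a recent PyTorch release, because the first-occurrence rule for repeated maxima is documented only in newer versions:

```python
import torch

t = torch.tensor([3, 7, 1, 7])

# Index of the largest element over the flattened tensor.
# 7 occurs twice, so the index of its first occurrence (1) is returned.
idx = torch.argmax(t)
print(idx)         # tensor(1)
print(idx.item())  # 1

# With dim=, argmax returns one index per row (dim=1) or per column (dim=0).
m = torch.tensor([[1, 9, 2],
                  [8, 0, 8]])
print(torch.argmax(m, dim=1))  # tensor([1, 0])
```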

What is index in tensor?

In mathematics and mathematical physics, raising and lowering indices are operations on tensors which change their type. Raising and lowering indices are a form of index manipulation in tensor expressions.


2 Answers

I don't think PyTorch has a direct equivalent of list.index(). However, you can get a similar result by comparing the tensor to the value (tensor == number), which gives a boolean tensor, and then calling nonzero() on that result. For example:

    import torch

    t = torch.Tensor([1, 2, 3])
    # True where t equals 2; nonzero() returns the indices of the True entries
    print((t == 2).nonzero(as_tuple=True)[0])

This prints

    tensor([1])

which is a 1-D LongTensor containing the index of every element equal to 2.
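The nonzero() approach returns every matching index. If you want to mimic `list.index()` more closely, you could wrap it in a small helper that returns only the first match and raises `ValueError` when the value is absent. `tensor_index` is a hypothetical name, not a PyTorch function. This is a sketch for 1-D tensors only:

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper (not part of PyTorch): first index of `value` in a 1-D tensor."""
    # Boolean mask of matches, then the indices where the mask is True.
    matches = (t == value).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        # list.index() raises ValueError for a missing value; do the same here.
        raise ValueError(f"{value} is not in tensor")
    # nonzero() returns indices in ascending order, so the first one is the first match.
    return int(matches[0])


t = torch.tensor([1, 2, 3, 2])
print(tensor_index(t, 2))  # 1
```

Note that `==` is an exact comparison. For floats produced by computation, a tolerance-based mask such as `torch.isclose(t, torch.tensor(value))` may be more appropriate.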

answered Sep 22 '22 at 04:09 by Manuel Lagunas


For multidimensional tensors, you can do:

    (tensor == target_value).nonzero(as_tuple=False)

The result is a tensor of shape number_of_matches x tensor_dimension, with one row of indices per match. For example, if tensor is a 3 x 4 tensor (that is, it has 2 dimensions), the result is a 2-D tensor whose rows are the (row, column) indices of the matches.

    >>> import torch
    >>> tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
    >>> (tensor == 2).nonzero(as_tuple=False)
    tensor([[0, 1],
            [0, 2],
            [1, 2]])
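The sketch below compares the two output layouts on the same 3 x 4 tensor. It assumes PyTorch 1.11 or later for `torch.argwhere`, which produces the same result as `nonzero(as_tuple=False)`. With `as_tuple=True`, the result is one 1-D index tensor per dimension, and that tuple can index the original tensor directly:

```python
import torch

tensor = torch.tensor([[1, 2, 2, 7],
                       [3, 1, 2, 4],
                       [3, 1, 9, 4]])

# One row per match: shape (number_of_matches, tensor.dim()).
coords = (tensor == 2).nonzero(as_tuple=False)
print(coords.shape)  # torch.Size([3, 2])

# Same result via torch.argwhere (assumes PyTorch >= 1.11).
assert torch.equal(coords, torch.argwhere(tensor == 2))

# One 1-D index tensor per dimension: (row_indices, col_indices).
rows, cols = (tensor == 2).nonzero(as_tuple=True)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([1, 2, 2])

# The tuple form can be used as an index, returning the matched values.
print(tensor[rows, cols])  # tensor([2, 2, 2])
```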
answered Sep 24 '22 at 04:09 by dopexxx