Get indices of elements in tensor a that are present in tensor b

For example, I want to get the indices of the elements of tensor a whose values are 0 or 2. These values (0 and 2) are stored in tensor b. I have devised a Pythonic way to do this (shown below), but I doubt a list comprehension is efficient: it loops over a one element at a time in Python, so it cannot take advantage of the GPU. Maybe there is a more idiomatic PyTorch way to do it that I am unaware of.

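import torch

a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])  # values to look up in a

# Test each element of a for membership in b, then collect the True positions
torch.tensor([x in b for x in a]).nonzero()
# tensor([[0],
#         [2],
#         [5],
#         [6]])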

Any other suggestions or is this an acceptable way?
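To make the per-element cost of the list comprehension concrete, the loop below spells out roughly what `x in b` does. As far as I know, `Tensor.__contains__` evaluates something like `(element == self).any().item()`; this is an approximation of PyTorch internals, not the exact implementation. The `.item()` call copies one value back to Python on every iteration, so with CUDA tensors each step waits for the GPU.

```python
import torch

a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])

# Approximately what the list comprehension does: one small comparison
# against b per element of a, reduced with any(), then .item() moves the
# result into Python. This runs len(a) separate iterations.
flags = []
for x in a:
    flags.append(bool((x == b).any().item()))

# torch.tensor(flags) is built from Python bools, so it lives on the CPU
# even if a and b were on the GPU.
idx = torch.tensor(flags).nonzero()
print(idx.squeeze(1))  # tensor([0, 2, 5, 6])
```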

asked Mar 29 '20 at 17:03 by Douglas De Rizzo Meneghetti


1 Answer

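Here's a more efficient way to do it, as suggested in the link jodag posted in the comments. It compares every element of a against every element of b in a single vectorized (broadcast) operation, so there is no Python-level loop and the work runs on whatever device a and b are on: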

(a[..., None] == b).any(-1).nonzero()
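The one-liner relies on broadcasting. The sketch below splits it into steps and prints the intermediate shape. It also checks the result against `torch.isin`, a built-in membership test available in PyTorch 1.10 and later; that is a different route to the same boolean mask, not what the answer above uses.

```python
import torch

a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])

# a[..., None] appends a trailing axis: shape (7,) -> (7, 1)
col = a[..., None]
# Comparing (7, 1) with (2,) broadcasts to a (7, 2) boolean matrix;
# entry [i, j] is True when a[i] == b[j]
matches = col == b
# any(-1) reduces over the b axis: True where a[i] equals some value in b
mask = matches.any(-1)
# nonzero() returns the True positions with shape (num_hits, 1);
# squeeze(1) turns that into a flat 1-D index tensor
flat_idx = mask.nonzero().squeeze(1)

# Alternative (PyTorch >= 1.10): torch.isin builds the same mask directly
isin_idx = torch.isin(a, b).nonzero().squeeze(1)

print(matches.shape)  # torch.Size([7, 2])
print(flat_idx)       # tensor([0, 2, 5, 6])
```

One design consequence: the broadcast version materializes a len(a) × len(b) boolean matrix, so its memory use grows with both sizes. For a large b, `torch.isin` may be the better choice, though that is worth measuring on your own data.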
answered Sep 30 '22 at 06:09 by Andreas K.