For example, I want to get the indices of the elements valued 0 and 2 in tensor a. These values (0 and 2) are stored in tensor b. I have devised a Pythonic way to do so (shown below), but I don't think list comprehensions are optimized to run on a GPU, and there may be a more PyTorch-native way to do it that I am unaware of.
import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()
>>>> tensor([[0],
             [2],
             [5],
             [6]])
Any other suggestions or is this an acceptable way?
We can index a tensor with another tensor, and sometimes we can successfully index a tensor with a NumPy array. The following code works for some dims: (2, 3), (2, 3, 2)
We use the torch.numel() function to find the total number of elements in the tensor.
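A quick sketch of what that looks like, using a made-up tensor of shape (2, 3, 2); the variable names here are just for illustration:

```python
import torch

# Hypothetical example tensor with shape (2, 3, 2)
t = torch.zeros(2, 3, 2)

# Total element count is the product of the dimensions: 2 * 3 * 2
n_function = torch.numel(t)   # function form
n_method = t.numel()          # equivalent method form

print(n_function, n_method)
```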
To find the index of the maximum value in an input tensor, we can apply the torch.argmax() function. It returns only the index, not the element value. Called without a dim argument, it returns the index into the flattened tensor; with dim, it returns one index per slice along that dimension. If the input tensor has multiple maximal values, the function returns the index of the first maximal element.
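Here is a small sketch of that behaviour on a made-up tensor with tied maxima. The first-occurrence result is what recent PyTorch releases document, so treat it as an assumption on very old versions:

```python
import torch

# Hypothetical example: the value 9 appears three times
x = torch.tensor([[1, 9, 3],
                  [9, 2, 9]])

# No dim: index into the flattened tensor [1, 9, 3, 9, 2, 9]
flat_idx = torch.argmax(x)

# dim=1: one index per row
row_idx = torch.argmax(x, dim=1)

print(flat_idx, row_idx)
```

If you need the value as well as the index, torch.max(x, dim=1) returns both.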
A PyTorch tensor is homogeneous, i.e., all the elements of a tensor are of the same data type. We can access the data type of a tensor using its ".dtype" attribute.
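As a rough illustration (assuming PyTorch's default dtype has not been changed), mixing integer and float literals still yields a single dtype for the whole tensor:

```python
import torch

ints = torch.tensor([1, 2, 3])        # all Python ints
mixed = torch.tensor([1.0, 2, 3])     # one float promotes every element

print(ints.dtype)    # integer dtype
print(mixed.dtype)   # floating-point dtype

# Converting explicitly gives a new tensor with the requested dtype
as_float = ints.to(torch.float64)
print(as_float.dtype)
```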
Here's a more efficient way to do it (as suggested in the link posted by jodag in comments...):
(a[..., None] == b).any(-1).nonzero()
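To spell out why this works: a[..., None] has shape (7, 1), so comparing it with b (shape (2,)) broadcasts to a (7, 2) boolean matrix, and .any(-1) collapses each row to "does this element equal any value in b". Note that the intermediate matrix has len(a) * len(b) entries, which can get large. As a sketch of an alternative, and assuming PyTorch 1.10 or newer, torch.isin does the membership test directly:

```python
import torch

a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])

# Broadcasting approach: (7, 1) == (2,) -> (7, 2), then reduce over the last dim
mask = (a[..., None] == b).any(-1)
idx_broadcast = mask.nonzero()

# Membership test without building the (7, 2) comparison by hand
# (assumes torch.isin is available, i.e. PyTorch >= 1.10)
idx_isin = torch.isin(a, b).nonzero()

# nonzero() returns shape (N, 1) for 1-D input; squeeze for a flat index list
flat = idx_isin.squeeze(1)

print(idx_broadcast)
print(flat)
```

Both versions stay on whatever device a and b live on, so there is no Python-level loop and no round trip to the CPU.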