Filter data in a PyTorch tensor

Tags:

python

pytorch

I have a tensor X like [0.1, 0.5, -1.0, 0, 1.2, 0], and I want to implement a function called filter_positive() that copies the positive values into a new tensor and also returns their indices in the original tensor. For example:

new_tensor, index = filter_positive(X)

new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]

How can I implement this function most efficiently in PyTorch?

asked Aug 20 '19 08:08 by dodolong


1 Answer

Take a look at torch.nonzero, which is roughly equivalent to np.where: it translates a boolean mask into indices.

>>> import torch
>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X > 0
>>> mask
tensor([ True,  True, False, False,  True, False])

>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
        [1],
        [4]])

>>> X[indices]
tensor([[0.1000],
        [0.5000],
        [1.2000]])
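torch.nonzero returns one row per True element and one column per input dimension, so for a 1-D mask the indices come back with shape (N, 1), which is also why X[indices] keeps an extra dimension. The small sketch below uses the strictly positive mask X > 0 and assumes NumPy is installed. It compares that default form with as_tuple=True, which mirrors np.where:

```python
import numpy as np
import torch

X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
mask = X > 0  # strictly positive: zeros are excluded

# Default form: one row per True element, one column per dimension -> shape (N, 1).
idx_2d = torch.nonzero(mask)

# as_tuple=True returns one 1-D index tensor per dimension, like np.where(mask).
(idx_1d,) = torch.nonzero(mask, as_tuple=True)

# NumPy counterpart, for comparison only.
(np_idx,) = np.where(mask.numpy())

print(idx_2d.shape)     # torch.Size([3, 1])
print(idx_1d.tolist())  # [0, 1, 4]
print(np_idx.tolist())  # [0, 1, 4]
```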

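Since the question asks for strictly positive values, use X > 0 (X >= 0 would also keep the zeros). A solution would then be to write:

mask = X > 0
new_tensor = X[mask]
indices = torch.nonzero(mask).squeeze(1)  # drop the trailing dimension: shape (N,) instead of (N, 1)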
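Wrapped into the filter_positive() function from the question, a minimal sketch (assuming a 1-D input tensor) could look like this:

```python
from typing import Tuple

import torch


def filter_positive(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the strictly positive elements of a 1-D tensor and their indices."""
    mask = x > 0  # boolean mask; zeros and negatives are False
    values = x[mask]  # boolean indexing returns a new tensor with the selected elements
    # as_tuple=True yields one 1-D int64 tensor per dimension; [0] assumes a 1-D input.
    indices = torch.nonzero(mask, as_tuple=True)[0]
    return values, indices


X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
new_tensor, index = filter_positive(X)
print(new_tensor)  # tensor([0.1000, 0.5000, 1.2000])
print(index)       # tensor([0, 1, 4])
```

Both outputs are derived from the same mask, so values and indices stay aligned. Everything runs as vectorized tensor operations, with no Python-level loop over the elements.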
answered Sep 23 '22 10:09 by nemo