I have a tensor X like [0.1, 0.5, -1.0, 0, 1.2, 0], and I want to implement a function called filter_positive() that puts the positive values into a new tensor and also returns their indices in the original tensor. For example:
new_tensor, index = filter_positive(X)
new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]
How can I implement this function most efficiently in pytorch?
The detach() method in PyTorch separates a tensor from the computational graph: it returns a new tensor that doesn't require a gradient.
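A minimal sketch of that behavior, assuming a recent PyTorch release: the detached tensor holds the same values but no longer tracks gradients, while the original still does.

```python
import torch

# A leaf tensor that tracks gradients
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2          # y is part of the computational graph
z = y.detach()     # same values as y, but cut off from the graph

print(y.requires_grad)  # True
print(z.requires_grad)  # False
print(z.grad_fn)        # None: z has no backward function
```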
We can access the data type of a tensor through its .dtype attribute, which returns a torch.dtype such as torch.float32.
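A short illustration, assuming the default floating-point dtype has not been changed with torch.set_default_dtype():

```python
import torch

a = torch.tensor([0.1, 0.5, -1.0])  # Python floats -> default float dtype
b = torch.tensor([1, 2, 3])         # Python ints -> 64-bit integers

print(a.dtype)  # torch.float32
print(b.dtype)  # torch.int64
```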
We can compare two tensors element-wise with the torch.eq() method. It compares the corresponding elements of the two tensors and returns True at each position where they are equal and False elsewhere.
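A small sketch of the element-wise comparison; the tensors p and q are arbitrary examples. For contrast, torch.equal() collapses the comparison into a single Python bool.

```python
import torch

p = torch.tensor([1, 2, 3, 4])
q = torch.tensor([1, 0, 3, 0])

result = torch.eq(p, q)  # element-wise, equivalent to p == q
print(result)            # tensor([ True, False,  True, False])

same = torch.equal(p, q) # True only if shapes and all elements match
print(same)              # False
```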
Take a look at torch.nonzero, which is roughly equivalent to np.where. It translates a binary mask to indices:
>>> import torch
>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X > 0
>>> mask
tensor([ True,  True, False, False,  True, False])
>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
        [1],
        [4]])
>>> X[indices]
tensor([[0.1000],
        [0.5000],
        [1.2000]])
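For a 1-D input, torch.nonzero(mask) returns a 2-D tensor of shape (n, 1), which is why indexing with it gives a column. A sketch of ways to get flat 1-D indices instead, assuming PyTorch 1.3 or later for the as_tuple argument; torch.where() called with only a condition is documented as identical to nonzero(as_tuple=True).

```python
import torch

X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
mask = X > 0  # strictly positive, so the zeros are excluded

# Option 1: drop the trailing dimension of the (n, 1) result
idx_a = torch.nonzero(mask).squeeze(1)

# Option 2: as_tuple=True returns one 1-D index tensor per dimension
idx_b = torch.nonzero(mask, as_tuple=True)[0]

# Option 3: torch.where with a single condition behaves like Option 2
idx_c = torch.where(mask)[0]

print(idx_a)     # tensor([0, 1, 4])
print(X[idx_a])  # tensor([0.1000, 0.5000, 1.2000])
```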
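A solution would then be to write:
import torch
mask = X > 0                              # strictly positive, so zeros are excluded
new_tensor = X[mask]                      # boolean indexing keeps only the positive values
indices = torch.nonzero(mask).squeeze(1)  # flatten the (n, 1) result to shape (n,)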
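Packaged as the filter_positive() function from the question, a minimal sketch assuming a 1-D input tensor. It uses X > 0 so that zeros are excluded, matching the expected output in the question, and gathers values with the computed indices rather than applying the mask a second time.

```python
import torch

def filter_positive(X: torch.Tensor):
    """Return (values, indices) for the strictly positive entries of a 1-D tensor X."""
    mask = X > 0                                   # True where X is positive
    index = torch.nonzero(mask, as_tuple=True)[0]  # 1-D positions in X
    new_tensor = X[index]                          # same values as X[mask]
    return new_tensor, index

X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
new_tensor, index = filter_positive(X)
print(new_tensor)  # tensor([0.1000, 0.5000, 1.2000])
print(index)       # tensor([0, 1, 4])
```

Both steps are vectorized, so there is no Python-level loop over elements. For an N-D input, as_tuple=True returns one index tensor per dimension, so taking only [0] would need adjusting.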