My PyTorch code:
import torch
x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
                  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
print(x.shape)
mask = torch.tensor([[False, False, True, False, True],
                     [ True,  True, True, False, False]])
print(mask.shape)
y = torch.tensor([[0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.]])
print(y.shape)
y.masked_scatter_(mask, x)
print(y)
The result is:
torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 5])
tensor([[0.0000, 0.0000, 0.3992, 0.0000, 0.2908],
        [0.9004, 0.4850, 0.6004, 0.0000, 0.0000]])
I expected the result to be:
tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])
My PyTorch version is 1.4.
You are right: this is confusing, and there is virtually no documentation on it.
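However, as you have discovered, masked_scatter_ does not copy the value at the position of each True. Instead, it walks the mask in row-major order across the whole tensor (not row by row) and gives the i-th True the i-th element of the flattened source. That is why the first row receives 0.3992 and 0.2908, and the second row continues with 0.9004, 0.4850 and 0.6004.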
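To make the filling order concrete, here is a minimal pure-Python model of that rule using nested lists instead of tensors. The function name masked_scatter is hypothetical and only illustrates the observed behaviour; it is not PyTorch's implementation.

```python
# Hypothetical pure-Python model of the masked_scatter_ filling rule.
# Assumption: the source is consumed in flattened (row-major) order across
# the whole tensor, which matches the output shown in the question.

def masked_scatter(dest, mask, source):
    """Return a copy of 2-D list `dest` where the k-th True in `mask`
    (row-major order) receives the k-th element of flattened `source`."""
    flat_source = [v for row in source for v in row]
    out = [row[:] for row in dest]  # copy so `dest` is left unchanged
    k = 0  # index of the next source element to consume
    for i, mask_row in enumerate(mask):
        for j, m in enumerate(mask_row):
            if m:
                out[i][j] = flat_source[k]
                k += 1
    return out

x = [[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
     [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]
mask = [[False, False, True, False, True],
        [True, True, True, False, False]]
y = [[0.0] * 5 for _ in range(2)]

print(masked_scatter(y, mask, x))
# [[0.0, 0.0, 0.3992, 0.0, 0.2908], [0.9004, 0.485, 0.6004, 0.0, 0.0]]
```

The second row starts with 0.9004 because the counter k is shared across rows, reproducing the "surprising" tensor from the question.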
Luckily, what you are trying to do is easy to achieve with ordinary boolean-mask indexing:
>>> y[mask] = x[mask]
>>> y
tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])
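Two other ways to get the same tensor, sketched with the tensors from the question and assuming a PyTorch version with bool masks (as the question already uses): passing x[mask] as the source makes masked_scatter_ do what you wanted, because x[mask] is a 1-D tensor already in the same row-major order that masked_scatter_ fills; and torch.where selects elementwise without modifying y in place.

```python
import torch

x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
                  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
mask = torch.tensor([[False, False, True, False, True],
                     [True, True, True, False, False]])

# Option 1: keep masked_scatter_, but feed it only the selected values.
y1 = torch.zeros_like(x)
y1.masked_scatter_(mask, x[mask])

# Option 2: take x where mask is True, otherwise zeros; returns a new tensor.
y2 = torch.where(mask, x, torch.zeros_like(x))

# Reference: the boolean-indexing assignment from the answer.
y3 = torch.zeros_like(x)
y3[mask] = x[mask]

print(torch.equal(y1, y3), torch.equal(y2, y3))  # True True
```

torch.where is handy when you do not want to mutate y; the indexing assignment and masked_scatter_ both work in place.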