My PyTorch code:
import torch
x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
[0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
print(x.shape)
mask = torch.tensor([[False, False, True, False, True],
[ True, True, True, False, False]])
print(mask.shape)
y = torch.tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
print(y.shape)
y.masked_scatter_(mask, x)
print(y)
The result is:
torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 5])
tensor([[0.0000, 0.0000, 0.3992, 0.0000, 0.2908],
[0.9004, 0.4850, 0.6004, 0.0000, 0.0000]])
I expected the result to be:
tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])
My PyTorch version is 1.4.
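You are right, this is confusing, and there is virtually no documentation on it.
As you have discovered, masked_scatter_ does not copy the source value at the position of each True. Instead, the i-th True in the mask (counting across the whole tensor in row-major order) is given the i-th element of the source (also read in row-major order). In your example the two Trues in row 0 receive the first two source values (0.3992, 0.2908), and the three Trues in row 1 receive the next three (0.9004, 0.4850, 0.6004).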
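To make that ordering concrete, here is a minimal sketch that reproduces masked_scatter_ by hand on the tensors from the question: flatten the mask and the target, then hand the first mask.sum() elements of the flattened source to the True positions in order. The names y_builtin, y_manual and n_true are purely illustrative, and this is a way to visualize the behaviour, not PyTorch's actual implementation.

```python
import torch

x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
                  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
mask = torch.tensor([[False, False, True, False, True],
                     [True, True, True, False, False]])

# Built-in: fill the True positions of a zero tensor from x, in order.
y_builtin = torch.zeros_like(x)
y_builtin.masked_scatter_(mask, x)

# Manual equivalent (illustrative): flatten everything in row-major order and
# give the first n_true source elements to the True positions one by one.
y_manual = torch.zeros_like(x)
n_true = int(mask.sum())
flat_y = y_manual.view(-1)  # a view, so writing to it updates y_manual
flat_y[mask.view(-1)] = x.reshape(-1)[:n_true]

print(torch.equal(y_builtin, y_manual))  # True
print(y_manual)
```

The manual version makes it clear why the source only needs at least as many elements as there are Trues in the mask: the values are consumed sequentially, and their original positions in x play no role.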
Luckily, what you are trying to do is easy with ordinary boolean-mask indexing:
>>> y[mask] = x[mask]
>>> y
tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
[0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])
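If you prefer other spellings, here is a short sketch of two alternatives that should give the same tensor. Both use only standard PyTorch calls. Passing x[mask] (the values already picked out at the True positions) as the source makes masked_scatter_ place each value back where it came from. torch.where builds a new tensor and leaves y untouched. The variable names y1, y2 and y3 are just for illustration.

```python
import torch

x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
                  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
mask = torch.tensor([[False, False, True, False, True],
                     [True, True, True, False, False]])

# Option 1: pre-select the source so the i-th True receives the i-th selected
# value, which is exactly the value of x at that same position.
y1 = torch.zeros_like(x)
y1.masked_scatter_(mask, x[mask])

# Option 2: out-of-place element-wise select; y itself is not modified.
y = torch.zeros_like(x)
y2 = torch.where(mask, x, y)

# Option 3: the in-place boolean indexing from the answer above.
y3 = torch.zeros_like(x)
y3[mask] = x[mask]

print(torch.equal(y1, y2) and torch.equal(y2, y3))  # True
print(y2)
```

Use torch.where when you want a fresh tensor; use the indexing or masked_scatter_ form when you need to update y in place.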