
torch.masked_scatter result did not meet expectations

Tags:

pytorch

My PyTorch code:

import torch
x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
            [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
print(x.shape)
mask = torch.tensor([[False, False,  True, False,  True],
            [ True,  True,  True, False, False]])
print(mask.shape)
y = torch.tensor([[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]])
print(y.shape)
y.masked_scatter_(mask, x)
print(y)

The result is:

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 5])
tensor([[0.0000, 0.0000, 0.3992, 0.0000, 0.2908],
        [0.9004, 0.4850, 0.6004, 0.0000, 0.0000]])

I expected the result to be:

tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])

My PyTorch version is 1.4.

Chuanhua Yang asked Sep 03 '25 16:09


1 Answer

You are right, this is confusing, and the documentation says very little about it.

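However, the way masked_scatter_ works (as you have discovered) is that the i-th True in the mask, counted in row-major order across the whole tensor, is given the i-th value from the source. It does not take the value at the same position as the True. That is why the second row of your output starts with 0.9004, which is the third element of x.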
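To make that ordering concrete, here is a small sketch that reproduces the masked_scatter_ output by hand. It assumes the same x and mask as in the question; the names scattered, manual and n_true are only illustrative.

```python
import torch

x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
                  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
mask = torch.tensor([[False, False, True, False, True],
                     [True, True, True, False, False]])

# masked_scatter_ consumes the source elements in order, one per True.
scattered = torch.zeros_like(x).masked_scatter_(mask, x)

# Same thing by hand: take the first n_true elements of the flattened
# source and write them into the True positions in row-major order.
n_true = int(mask.sum())
manual = torch.zeros_like(x)
manual[mask] = x.flatten()[:n_true]

print(scattered)
print(torch.equal(scattered, manual))
```

Both tensors should match, which suggests the mask only decides where values are written, not which source values are read.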

Luckily what you are trying to do can easily be achieved using the normal indexing notation:

>>> y[mask] = x[mask]
>>> y
tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])
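If you would rather not modify y in place, or you still want to use masked_scatter_, the following sketch shows two alternatives that should give the same tensor as the indexing above. It assumes the x and mask from the question and a zero-filled y; the names out, via_scatter and expected are only illustrative.

```python
import torch

x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
                  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
mask = torch.tensor([[False, False, True, False, True],
                     [True, True, True, False, False]])
y = torch.zeros_like(x)

# Reference result from the indexing approach, computed on a copy of y.
expected = y.clone()
expected[mask] = x[mask]

# Out-of-place alternative: pick x where mask is True, y elsewhere.
out = torch.where(mask, x, y)

# masked_scatter_ also works if the source holds only the selected
# values, because x[mask] is already in the order the Trues are visited.
via_scatter = y.clone().masked_scatter_(mask, x[mask])

print(out)
print(torch.equal(out, expected), torch.equal(via_scatter, expected))
```

torch.where returns a new tensor and leaves y untouched, which can be convenient when y is needed later or is part of an autograd graph.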

Thomas Ahle answered Sep 05 '25 16:09


