A similar question has already been asked here, but I don't think its solution fits my case.
I wonder why it is not possible to do a torch.scatter operation where my index tensor is larger than the tensor I am scattering into. In my case the index tensor contains duplicate indices, as with the following target tensor a and index tensor idx:
a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
Calling a.scatter(-1, idx, 1) raises:
RuntimeError: Expected index [2, 5] to be smaller than self [2, 4] apart from dimension 1 and to be smaller size than src [2, 4]
Is there another way to achieve this?
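One variant that may avoid the error is to pass a source tensor with the same shape as idx instead of the scalar 1. The error message says index must be no larger than self in every dimension except dim 1, and no larger than src in every dimension. With the scalar value, src is reported as [2, 4], the shape of self. A tensor src of shape [2, 5] satisfies both conditions. This is a minimal sketch under that reading of the shape check (the name `src` is only illustrative). Because every duplicate index writes the same value, the result does not depend on which duplicate write lands last.

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# src has the same shape as idx ([2, 5]), so idx is never larger than src.
# Along dim=-1 the index may be longer than a; the other dims still match.
# torch.ones_like(idx) is int64, which matches the dtype of a.
src = torch.ones_like(idx)

# Out-of-place scatter: a itself is left unchanged.
result = a.scatter(-1, idx, src)
print(result)
# tensor([[0, 1, 1, 1],
#         [1, 1, 1, 0]])
```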
Not a solution, but a workaround:
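import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# Column vector of row indices, shape [a.size(0), 1]
rows = torch.arange(0, a.size(0))[:, None]
n_col = idx.size(1)
# Pair every column index in idx with its row index and set those entries to 1
a[rows.repeat(1, n_col), idx] = 1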
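rows.repeat(1, n_col) expands the row indices to the shape of idx, so each column index in idx is paired with the index of the row it belongs to. Duplicate indices are harmless here because every write stores the same value.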
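The explicit repeat can likely be dropped: PyTorch broadcasts index tensors used in advanced indexing against each other, so a row index of shape [2, 1] pairs with a column index of shape [2, 5] automatically. A minimal sketch comparing both forms on copies of a (the names `a_repeat` and `a_bcast` are only for illustration):

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

rows = torch.arange(a.size(0))[:, None]  # shape [2, 1]

# Workaround form: expand rows explicitly to idx's shape [2, 5].
a_repeat = a.clone()
a_repeat[rows.repeat(1, idx.size(1)), idx] = 1

# Same assignment, letting rows broadcast against idx.
a_bcast = a.clone()
a_bcast[rows, idx] = 1

print(torch.equal(a_repeat, a_bcast))  # True
print(a_bcast)
# tensor([[0, 1, 1, 1],
#         [1, 1, 1, 0]])
```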