I'm new to PyTorch and I encountered this error:
x.gather(1, c)
RuntimeError: Invalid index in gather at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:457
Here is some information about the tensors:
print(x.size())
print(c.size())
print(type(x))
print(type(c))
torch.Size([128, 2])
torch.Size([128, 1])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
x is filled with float values and c with integers. Could that be the problem?
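This simply means your index tensor c contains invalid indices. With dim=1, every value in c must be between 0 and x.size(1) - 1. Since your x has size [128, 2], c may only contain 0 or 1. The float values in x are not the problem, as long as c is an integer (long) tensor.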
For example, the following index tensor is valid:
x = torch.tensor([
[5, 9, 1],
[3, 2, 8],
[7, 4, 0]
])
c = torch.tensor([
[0, 0, 0],
[1, 2, 0],
[2, 2, 1]
])
x.gather(1, c)
>>> tensor([[5, 5, 5],
            [2, 8, 3],
            [0, 0, 4]])
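To see why the values in c are constrained, it may help to spell out what gather does along dim=1: each output element is out[i][j] = x[i][c[i][j]]. The loop below is only an illustrative sketch of that lookup (the name `out` is made up for this example), not how PyTorch implements it:

```python
import torch

x = torch.tensor([[5, 9, 1], [3, 2, 8], [7, 4, 0]])
c = torch.tensor([[0, 0, 0], [1, 2, 0], [2, 2, 1]])

# For dim=1, gather looks up x[i][c[i][j]] for every position (i, j) of c.
out = torch.empty_like(c)
for i in range(c.size(0)):
    for j in range(c.size(1)):
        out[i, j] = x[i, c[i, j]]

# The manual lookup should match the built-in result.
print(torch.equal(out, x.gather(1, c)))
```

Because c[i][j] is used directly as a column index into row i of x, it has to be a valid column number for x.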
However, the following index tensor is invalid, because 3 is out of range for a dimension of size 3 (valid values are 0 to 2):
c = torch.tensor([
[0, 0, 0],
[1, 2, 0],
[2, 2, 3]
])
It raises the exception you mention:
RuntimeError: Invalid index in gather
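One way to track this down is to check the index tensor before calling gather. The helper below is a minimal sketch; `check_gather_index` is a hypothetical name, not part of PyTorch, and it assumes the index tensor should be of dtype torch.long with non-negative values:

```python
import torch

def check_gather_index(x, c, dim):
    """Return True if every entry of c is a usable index into x along dim."""
    # Assumption: gather expects an int64 (long) index tensor.
    if c.dtype != torch.long:
        return False
    # Every index must lie in [0, x.size(dim) - 1].
    return bool(((c >= 0) & (c < x.size(dim))).all())

x = torch.tensor([[5, 9, 1], [3, 2, 8], [7, 4, 0]])
good = torch.tensor([[0, 0, 0], [1, 2, 0], [2, 2, 1]])
bad = torch.tensor([[0, 0, 0], [1, 2, 0], [2, 2, 3]])

print(check_gather_index(x, good, 1))  # True
print(check_gather_index(x, bad, 1))   # False
```

In your case, printing c.min() and c.max() should show quickly whether any value falls outside 0 and 1.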