
PyTorch RuntimeError: Invalid index in gather

I'm new to PyTorch and I ran into this error:

x.gather(1, c)

RuntimeError: Invalid index in gather at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:457

Here is some information about the tensors:

print(x.size())
print(c.size())
print(type(x))
print(type(c))

torch.Size([128, 2])
torch.Size([128, 1])
<class 'torch.Tensor'>
<class 'torch.Tensor'>

x is filled with float values and c with integers. Could that be the problem?

Louis Beaumont asked Jan 26 '19 17:01


1 Answer

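This means your index tensor c contains values that are out of range for the dimension you gather along. With dim=1, every entry of c must be a valid column index of x, that is, between 0 and x.size(1) - 1. For example, for the 3x3 tensor x below, the following index tensor is valid: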

import torch

x = torch.tensor([
    [5, 9, 1],
    [3, 2, 8],
    [7, 4, 0]
])
c = torch.tensor([
    [0, 0, 0],
    [1, 2, 0],
    [2, 2, 1]
])
x.gather(1, c)
# tensor([[5, 5, 5],
#         [2, 8, 3],
#         [0, 0, 4]])
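
To see why these indices are valid, it may help to spell out what gather does along dim 1: each output element is out[i][j] = x[i][c[i][j]], so every value in c is used as a column index into the matching row of x. The sketch below rebuilds the same result with explicit loops; the names manual and same are only illustrative.

```python
import torch

x = torch.tensor([[5, 9, 1],
                  [3, 2, 8],
                  [7, 4, 0]])
c = torch.tensor([[0, 0, 0],
                  [1, 2, 0],
                  [2, 2, 1]])

# Along dim 1, gather computes out[i][j] = x[i][c[i][j]].
manual = torch.empty_like(c)
for i in range(c.size(0)):
    for j in range(c.size(1)):
        manual[i, j] = x[i, c[i, j]]

# True when the loop version matches the built-in gather.
same = torch.equal(manual, x.gather(1, c))
```

Because c[i][j] is used directly as a column index, it has to point at a column that exists in x.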

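However, the following index tensor is invalid, because its last entry, 3, is not a valid index into a dimension of size 3 (valid indices are 0, 1 and 2):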

c = torch.tensor([
    [0, 0, 0],
    [1, 2, 0],
    [2, 2, 3]
])

And it raises the exception you mention:

RuntimeError: Invalid index in gather
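
One way to track this down in your own code is to check the range of c before calling gather. The following is only a minimal sketch: check_gather_index is a hypothetical helper, not part of PyTorch, and it assumes that gather expects an int64 index whose values lie between 0 and x.size(dim) - 1. The shapes mirror the ones in the question ([128, 2] and [128, 1]), filled with made-up data.

```python
import torch

def check_gather_index(x, dim, index):
    """Hypothetical helper: raise a descriptive error before calling x.gather(dim, index)."""
    if index.dtype != torch.int64:
        raise TypeError(f"index must be int64 (torch.long), got {index.dtype}")
    if index.numel() == 0:
        return  # nothing to validate
    size = x.size(dim)
    lo, hi = int(index.min()), int(index.max())
    if lo < 0 or hi >= size:
        raise IndexError(
            f"index values span [{lo}, {hi}] but dim {dim} of x has size {size}; "
            f"expected values in [0, {size - 1}]"
        )

# Same shapes as in the question; the values here are made up.
x = torch.rand(128, 2)             # float values
c = torch.randint(0, 2, (128, 1))  # int64 indices, each 0 or 1

check_gather_index(x, 1, c)  # passes: every index is 0 or 1
out = x.gather(1, c)         # shape [128, 1]
```

Having floats in x and integers in c is the normal setup for gather, so the dtype mix is unlikely to be the cause. With x of size [128, 2], a more likely cause is a value in c other than 0 or 1, and printing c.min() and c.max() should confirm it.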

abdullah.cu answered Sep 24 '22 01:09