PyTorch RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed

Tags:

pytorch

torch

I'm trying to create a basic binary classifier in PyTorch that classifies whether my player plays on the right or the left side in the game Pong. The input is a 1x42x42 image and the label is my player's side (right = 1 or left = 2). The code:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

net = Net(42 * 42, 100, 2)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer_net = torch.optim.Adam(net.parameters(), 0.001)
net.train()

while True:
    state = get_game_img()
    state = torch.from_numpy(state).float()  # nn.Linear expects float32 input

    # right = 1, left = 2
    current_side = get_player_side()
    target = torch.LongTensor([current_side])  # one-element tensor holding the label
    x = Variable(state.view(-1, 42 * 42))
    y = Variable(target)
    optimizer_net.zero_grad()
    y_pred = net(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer_net.step()

The error I get:

  File "train.py", line 109, in train
    loss = criterion(y_pred, y)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 321, in forward
    self.weight, self.size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 533, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 501, in nll_loss
    return f(input, target)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /py/conda-bld/pytorch_1493676237139/work/torch/lib/THNN/generic/ClassNLLCriterion.c:57
asked Aug 19 '17 at 08:08 by Shani Gamrian


2 Answers

In most deep learning libraries, the target (or label) should start from 0.

That is, with n classes, each target must lie in the range [0, n), i.e. from 0 to n - 1.
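A minimal sketch of that rule, assuming a recent PyTorch (0.4 or later, where `Variable` is no longer needed). The helper `targets_in_range` is hypothetical, written only to mirror the failing assertion `cur_target >= 0 && cur_target < n_classes`; it is not part of the PyTorch API.

```python
import torch
import torch.nn as nn

def targets_in_range(targets, n_classes):
    """Hypothetical helper: True if every class index in `targets` is in [0, n_classes)."""
    # Same condition as the C assertion: cur_target >= 0 && cur_target < n_classes
    return bool(((targets >= 0) & (targets < n_classes)).all())

torch.manual_seed(0)
n_classes = 2
logits = torch.randn(3, n_classes)   # batch of 3 samples, 2 output scores each

good = torch.tensor([0, 1, 1])       # zero-based labels: all valid for 2 classes
bad = torch.tensor([1, 2, 2])        # one-based labels: 2 is out of range

print(targets_in_range(good, n_classes))  # True
print(targets_in_range(bad, n_classes))   # False

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, good)       # zero-based targets are accepted
print(loss.item())
```

Passing `bad` to `criterion` instead is what triggers the out-of-range failure (newer PyTorch versions report it as an error rather than the old THNN assertion).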

answered Oct 10 '22 at 09:10 by Jing


It looks like PyTorch expects zero-based labels (0/1 in your case), and you are probably feeding it one-based labels (1/2).
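One way to apply this is to shift the label by one before building the target tensor. The sketch below assumes a recent PyTorch (0.4 or later) and replaces the asker's `get_game_img()` / `get_player_side()` with random stand-in data; `side_to_class` is a hypothetical helper name, not a library function.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def side_to_class(side):
    """Hypothetical helper: map one-based side (right = 1, left = 2) to class 0 or 1."""
    return side - 1

torch.manual_seed(0)
net = Net(42 * 42, 100, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Stand-in data in place of get_game_img() / get_player_side()
state = torch.rand(1, 42, 42)
current_side = 2  # left

x = state.view(-1, 42 * 42)                                        # shape (1, 1764)
y = torch.tensor([side_to_class(current_side)], dtype=torch.long)  # tensor([1])

optimizer.zero_grad()
loss = criterion(net(x), y)
loss.backward()
optimizer.step()
print(loss.item())
```

Keeping `num_classes = 2` and remapping the labels means the output layer has exactly one score per side, which is what `nn.CrossEntropyLoss` expects.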

answered Oct 10 '22 at 11:10 by yossiB