 

PyTorch error: RuntimeError: expected scalar type Long but found Double

I have run into the following error while training a BERT classifier. The relevant types are:

type(b_input_mask) = type(b_labels) = torch.Tensor      

type(b_labels[i]) = tensor(1., dtype=torch.float64)

type(b_input_mask[i]) = <class 'torch.Tensor'>

What could be causing this data type error, given that I have not cast any variable to either long or double?

Thanks in advance!

[Image: error stack trace]

Soumya Ranjan Sahoo asked Sep 08 '25 09:09


1 Answer

In a classification task, the labels (targets) must have data type Long (torch.int64), because a loss such as nn.CrossEntropyLoss expects integer class indices. Your labels are float64 instead:

type(b_labels[i]) = tensor(1., dtype=torch.float64)

=>

type(b_labels[i]) = tensor(1)  (dtype=torch.long, i.e. torch.int64)
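Note that `type()` on any tensor just returns `torch.Tensor`; the element type is reported by `.dtype`. The sketch below shows one plausible way float64 labels appear without any explicit cast, and the fix. It assumes the labels were loaded from a float NumPy array (for example a pandas column) and that the loss is `nn.CrossEntropyLoss`; `raw_labels` and `logits` are hypothetical stand-ins for your data and the model output.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical labels from NumPy/pandas: a float array defaults to float64,
# so torch.tensor() yields a torch.float64 (Double) tensor with no explicit cast.
raw_labels = np.array([1.0, 0.0, 1.0])
b_labels = torch.tensor(raw_labels)
print(b_labels.dtype)  # torch.float64

# Stand-in for the classifier output: batch of 3 examples, 2 classes (assumed shape).
logits = torch.randn(3, 2)
loss_fn = nn.CrossEntropyLoss()

# Class-index targets must be Long; cast once before computing the loss.
b_labels = b_labels.long()  # same as b_labels.to(torch.long)
print(b_labels.dtype)  # torch.int64

loss = loss_fn(logits, b_labels)
print(loss.item())
```

Casting once when the label tensor is created (rather than inside every training step) keeps every batch drawn from it in the right dtype.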
Bugface answered Sep 10 '25 04:09