 

RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead

I wrote the code below using PyTorch and ran into a runtime error:

import torch
tns = torch.tensor([1,0,1])
tns.mean()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-666-194e5ab56931> in <module>
----> 1 tns.mean()

RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead.

However, if I change the tensor to float, the error goes away:

tns = torch.tensor([1.,0,1])
tns.mean()
---------------------------------------------------------------------------
tensor(0.6667)
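The difference between the two snippets comes from dtype inference: torch.tensor picks the dtype from the Python literals, so a list of ints becomes torch.int64, while a single float literal promotes the whole tensor to the default floating dtype (torch.float32 unless changed with torch.set_default_dtype). A minimal check, assuming a standard PyTorch install with the default dtype untouched:

```python
import torch

# Integer literals only: dtype is inferred as int64
int_tns = torch.tensor([1, 0, 1])
print(int_tns.dtype)    # torch.int64

# One float literal is enough to make the whole tensor floating point
float_tns = torch.tensor([1., 0, 1])
print(float_tns.dtype)  # torch.float32

# mean() is defined for the floating point tensor
print(float_tns.mean())
```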

My question is: why does this error happen? The data type of the first tensor is int64, not Long, so why does PyTorch treat it as Long?
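In PyTorch, torch.long is an alias for torch.int64, so "Long" in the error message and int64 in .dtype name the same 64-bit integer type; "Long" is the name used by the legacy tensor type strings. A quick check, assuming a standard PyTorch install and a CPU tensor:

```python
import torch

tns = torch.tensor([1, 0, 1])

# torch.long and torch.int64 refer to the same dtype
print(torch.long == torch.int64)  # True
print(tns.dtype == torch.long)    # True

# The legacy type string uses the "Long" name for this dtype
print(tns.type())                 # torch.LongTensor
```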

Chris asked May 15 '26 10:05



1 Answer

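You should change torch.tensor([1,0,1]) to torch.Tensor([1,0,1]). torch.tensor infers the dtype from the data, so a list of Python ints becomes an int64 tensor (PyTorch also calls this dtype torch.long, hence "Long" in the error), and mean() only accepts floating point or complex dtypes. torch.Tensor, by contrast, creates a tensor of the default floating type (float32 unless changed), so mean() works.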
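Switching to torch.Tensor is one option; a sketch of that and two alternatives that keep torch.tensor, assuming the default floating dtype is still float32:

```python
import torch

# Option 1: the legacy constructor produces the default floating type
a = torch.Tensor([1, 0, 1])

# Option 2: keep torch.tensor but request a floating dtype explicitly
b = torch.tensor([1, 0, 1], dtype=torch.float32)

# Option 3: convert an existing integer tensor before averaging
c = torch.tensor([1, 0, 1]).float()

for t in (a, b, c):
    print(t.dtype, t.mean())
```

Options 2 and 3 make the conversion explicit at the call site, which is why torch.tensor with a dtype argument is often preferred for building tensors from existing data.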

YJ Jang answered May 16 '26 23:05
