
Why doesn't my simple pytorch network work on GPU device?

I built a simple network from a tutorial and I got this error:

RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'

Any help? Thank you!

import torch
import torchvision

device = torch.device("cuda:0")
root = '.data/'

dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.out = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

net = Net()
net.to(device)

for i, (inputs, labels) in enumerate(dataloader):
    inputs.to(device)
    out = net(inputs)
Harkonnen asked Jul 31 '18 05:07



1 Answer

TL;DR
This is the fix

inputs = inputs.to(device)  
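Applied to the question's loop, the fix might look like the minimal sketch below. Two things are assumptions added for illustration and are not part of the original code: random tensors with MNIST's shapes stand in for the dataset (so nothing is downloaded), and the device falls back to the CPU when CUDA is unavailable. Labels are moved as well, since a loss computed from `out` and `labels` would need both on the same device.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumption: fall back to the CPU when no CUDA device is present
# (the question hard-codes "cuda:0").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.out = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each 1x28x28 image to 784 values
        return self.out(x)


# Assumption: random stand-in for MNIST with the same shapes, no download.
images = torch.rand(16, 1, 28, 28)
targets = torch.randint(0, 10, (16,))
dataloader = DataLoader(TensorDataset(images, targets), batch_size=4)

net = Net()
net.to(device)  # Module.to() moves the parameters in place

for inputs, labels in dataloader:
    # Tensor.to() returns a tensor on `device`; keep the returned value.
    inputs = inputs.to(device)
    labels = labels.to(device)
    out = net(inputs)

print(out.shape, out.device)
```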

Why?!
There is a subtle difference between torch.nn.Module.to() and torch.Tensor.to(): Module.to() modifies the module in place (and returns it), while Tensor.to() leaves the original tensor unchanged. Therefore

net.to(device)

changes net itself, moving its parameters and buffers to device. On the other hand

inputs.to(device)

does not change inputs; it returns a tensor that resides on device (a copy, unless inputs is already on device, in which case inputs itself is returned). To use the on-device tensor, you need to assign it to a variable, hence

inputs = inputs.to(device)
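The difference can be checked without a GPU. The sketch below converts the dtype instead of the device, on the assumption that this illustrates the same point: whether the target is a dtype or a device, Tensor.to() returns a result while Module.to() changes the module itself.

```python
import torch

# dtype conversion stands in for a device move so this runs on CPU only.
t = torch.zeros(3)                 # float32 by default
t.to(torch.float64)                # result discarded: t is unchanged
print(t.dtype)                     # torch.float32
t = t.to(torch.float64)            # reassign to keep the converted tensor
print(t.dtype)                     # torch.float64

layer = torch.nn.Linear(3, 2)
layer.to(torch.float64)            # converts the parameters in place
print(layer.weight.dtype)          # torch.float64

# Module.to() returns the module itself, so `net = net.to(device)` also works.
same = layer.to(torch.float64)
print(same is layer)               # True

# Tensor.to() returns self when no conversion is needed.
print(t.to(torch.float64) is t)    # True
```

Because Module.to() returns the module, writing `net = net.to(device)` is harmless, so reassigning in both cases is a simple habit that avoids this bug.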
Shai answered Sep 30 '22 21:09