
How to check if a model is on CUDA?

Tags:

pytorch

I would like to check whether a model is on CUDA. How can I do that?

import torch
import torchvision
model = torchvision.models.resnet18()
model.to('cuda')

It seems that model.is_cuda() does not work.

asked Dec 30 '22 at 20:12 by prosti


1 Answer

This code should do it:

import torch
import torchvision
model = torchvision.models.resnet18()
model.to('cuda')
next(model.parameters()).is_cuda

Out:

True
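Checking only next(model.parameters()) looks at the first parameter, so it can miss a model that is split across devices, and it raises StopIteration for a module with no parameters. A slightly more defensive sketch, assuming you want "on CUDA" to mean every parameter and buffer is on a CUDA device (the helper names model_devices and is_on_cuda are hypothetical, not part of PyTorch):

```python
import torch
import torch.nn as nn

def model_devices(model: nn.Module) -> set:
    """Return the set of devices holding the model's parameters and buffers."""
    # Buffers (e.g. BatchNorm running stats) are moved by .to() as well,
    # so include them alongside the parameters.
    tensors = list(model.parameters()) + list(model.buffers())
    return {t.device for t in tensors}

def is_on_cuda(model: nn.Module) -> bool:
    """True only if the model has tensors and all of them are on a CUDA device."""
    devices = model_devices(model)
    return bool(devices) and all(d.type == "cuda" for d in devices)

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
print(model_devices(model))  # {device(type='cpu')}
print(is_on_cuda(model))     # False

if torch.cuda.is_available():
    model.to("cuda")
    print(is_on_cuda(model))  # True
```

Returning the full set of devices also makes it easy to spot a model that was only partially moved.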

Note that there is no is_cuda() method on nn.Module. Also note that model.to('cuda') is equivalent to model.cuda(), and both modify the module in place (they also return the module itself).
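A small sketch of the in-place behaviour of nn.Module.to. It converts the dtype on the CPU instead of moving to 'cuda', only so that it runs on a machine without a GPU; the assumption is that .to('cuda') and .cuda() follow the same in-place rule.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
returned = layer.to(torch.float64)  # Module.to converts the parameters in place...

print(returned is layer)   # True: ...and returns the very same module object
print(layer.weight.dtype)  # torch.float64, no reassignment needed
```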

On the other hand, calling data.to('cuda') on a tensor is not in place: it returns a copy on the target device and leaves the original tensor untouched, so you typically write:

data = data.to('cuda')

to move the data to CUDA.
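The difference is easy to see with a dtype conversion, which (like the earlier sketch) is used here only so the example runs without a GPU. The device-agnostic `device` line at the end is a common pattern, shown as an illustration rather than something the answer prescribes.

```python
import torch

data = torch.zeros(3)          # float32 tensor on the CPU
data.to(torch.float64)         # result discarded: `data` itself is unchanged
print(data.dtype)              # torch.float32

data = data.to(torch.float64)  # reassign to keep the converted tensor
print(data.dtype)              # torch.float64

# Pick CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
print(data.is_cuda == (device.type == "cuda"))  # True
```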

answered Jan 31 '23 at 01:01 by prosti