 

model.cuda() in pytorch

Tags:

pytorch

If I call model.cuda() in PyTorch, where model is a subclass of nn.Module, and I have four GPUs, how will it utilize the four GPUs, and how do I know which GPUs are being used?

william007 asked Jan 26 '23 03:01



2 Answers

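If you have a custom module derived from nn.Module, then after model.cuda() all of the model's parameters (the model.parameters() iterator can show you these) will end up on a single CUDA device: the current one, which is cuda:0 unless you have changed it. Calling .cuda() on its own does not spread the model across your four GPUs.

To check where your parameters are, just print them (they are on cuda:0 in my case):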

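import torch
import torch.nn as nn

class M(nn.Module):
    'custom module'
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

m = M()
m.cuda()  # moves every parameter and buffer to the current CUDA device (cuda:0 by default)
for p in m.parameters():
    print(p)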

# Parameter containing:
# tensor([[-0.0201,  0.0282, -0.0258,  ...,  0.0056,  0.0146,  0.0220],
#         [ 0.0098, -0.0264,  0.0283,  ...,  0.0286, -0.0052,  0.0007],
#         [-0.0036, -0.0045, -0.0227,  ..., -0.0048, -0.0003, -0.0330],
#         ...,
#         [ 0.0217, -0.0008,  0.0029,  ..., -0.0213,  0.0005,  0.0050],
#         [-0.0050,  0.0320,  0.0013,  ..., -0.0057, -0.0213,  0.0045],
#         [-0.0302,  0.0315,  0.0356,  ...,  0.0259,  0.0166, -0.0114]],
#        device='cuda:0', requires_grad=True)
# Parameter containing:
# tensor([-0.0027, -0.0353, -0.0349, -0.0236, -0.0230,  0.0176, -0.0156,  0.0037,
#          0.0222, -0.0332], device='cuda:0', requires_grad=True) 
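Reading printed tensors gets tedious for large models. A minimal sketch of a programmatic check follows; it uses an nn.Linear as a stand-in for the custom module M, and the helper name parameter_devices is hypothetical. It also falls back to the CPU when no GPU is present, so it runs anywhere:

```python
import torch
import torch.nn as nn


def parameter_devices(module: nn.Module) -> set:
    """Hypothetical helper: the set of devices holding the module's parameters."""
    return {p.device for p in module.parameters()}


model = nn.Linear(784, 10)  # stand-in for the custom module M

if torch.cuda.is_available():
    # With no argument, .cuda() targets the current CUDA device (index 0 unless changed).
    print("current CUDA device:", torch.cuda.current_device())
    model.cuda()

# A single element in this set means the whole model sits on one device.
print(parameter_devices(model))
```

After a plain model.cuda() the set contains exactly one device, which confirms the model lives on one GPU rather than four.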

You can also specify the device like this:

m.cuda('cuda:0')
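The more general model.to(device) does the same move and also accepts the CPU. A minimal sketch, assuming you want the last visible GPU (that choice is purely illustrative) and falling back to the CPU otherwise; note that the inputs must be on the same device as the parameters:

```python
import torch
import torch.nn as nn

# Illustrative choice: use the last visible GPU, or fall back to the CPU.
if torch.cuda.is_available():
    index = torch.cuda.device_count() - 1
    device = torch.device("cuda", index)
else:
    device = torch.device("cpu")

model = nn.Linear(784, 10)
model.to(device)  # same effect as model.cuda(index) when device is a GPU

x = torch.randn(4, 784, device=device)  # inputs must live on the model's device
y = model(x)
print(y.device)
```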

With torch.cuda.device_count() you may check how many devices you have.
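A short sketch that lists every visible GPU by index and name; on a machine without usable GPUs it simply reports zero devices:

```python
import torch

count = torch.cuda.device_count()  # 0 when no CUDA device is usable
print(f"{count} CUDA device(s) visible")

names = [torch.cuda.get_device_name(i) for i in range(count)]
for i, name in enumerate(names):
    print(f"cuda:{i} -> {name}")
```

If the CUDA_VISIBLE_DEVICES environment variable is set, only the listed GPUs are visible to the process and they are renumbered starting from cuda:0.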

prosti answered Feb 01 '23 12:02



To expand on prosti's answer: to split your computations among multiple GPUs, you should use torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel.
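A minimal DataParallel sketch, assuming a single machine and using nn.Linear as a stand-in model. DataParallel replicates the module on each listed GPU, splits the input batch along dimension 0, and gathers the outputs on device_ids[0]. With one GPU or none the wrapper is skipped, so the snippet still runs:

```python
import torch
import torch.nn as nn

model = nn.Linear(784, 10)  # stand-in for your own module

if torch.cuda.device_count() > 1:
    # Replicate across all visible GPUs; each replica gets a slice of the batch.
    model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)  # DataParallel expects the parameters on device_ids[0]

x = torch.randn(32, 784, device=device)
out = model(x)  # outputs are gathered back on cuda:0
print(out.shape, out.device)
```

The PyTorch documentation recommends DistributedDataParallel over DataParallel even on a single machine, but it needs a process group and one process per GPU, which is more setup than fits in a short snippet.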

Szymon Maszke answered Feb 01 '23 11:02
