
RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1

Tags: python, pytorch

I have a snippet of Python code for training a model. The problem is that after running:

loaded_state = torch.load(model_path+seq_to_seq_test_model_fname)

to load a pretrained model, I get:

Traceback (most recent call last):
  File "img_to_text.py", line 480, in <module>
    main()
  File "img_to_text.py", line 475, in main
    r = setup_test()
  File "img_to_text.py", line 259, in setup_test
    s2s_data = s2s.setup_test()
  File "/media/ahrzb/datasets/notebooks/mzh/SemStyle/semstyle/code/seq2seq_pytorch.py", line 220, in setup_test
    loaded_state= torch.load(model_path+seq_to_seq_test_model_fname)
  File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 358, in load
    return _load(f, map_location, pickle_module)
  File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
  File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 508, in persistent_load
    data_type(size), location)
  File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 372, in restore_location
    return default_restore_location(storage, location)
  File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 104, in default_restore_location
    result = fn(storage, location)
  File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 85, in _cuda_deserialize
    device, torch.cuda.device_count()))
RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1

I think this is because the model was trained on multiple GPUs (the error refers to device 2, so at least three) and I need to load it on a single GPU. So I changed this line:

loaded_state = torch.load(model_path+seq_to_seq_test_model_fname) 

to

loaded_state = torch.load(model_path+seq_to_seq_test_model_fname, map_location={'cuda:1': 'cuda:0'})

to map the data from cuda:1 to cuda:0, but it did not work.
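A likely reason the dict did not help: when map_location is a dict, torch.load only remaps location tags that appear as keys and leaves every other tag unchanged. The error refers to device 2, so storages tagged "cuda:2" pass through untouched and still fail on a one-GPU machine. The sketch below builds a dict that covers several possible source devices. The helper name build_cuda_map and the default of at most 8 source GPUs are illustrative assumptions, not part of PyTorch.

```python
def build_cuda_map(target="cuda:0", max_saved_devices=8):
    """Map every CUDA location tag "cuda:0" .. "cuda:N-1" onto one target.

    Hypothetical helper. torch.load leaves tags that are not dict keys
    unchanged, so a dict listing only "cuda:1" cannot fix a storage that
    was saved with the tag "cuda:2".
    """
    return {"cuda:{}".format(i): target for i in range(max_saved_devices)}


# Assumes the checkpoint was saved from at most 4 GPUs.
mapping = build_cuda_map("cuda:0", max_saved_devices=4)
print(mapping)

# Usage with the names from the question (not executed here):
# loaded_state = torch.load(model_path + seq_to_seq_test_model_fname,
#                           map_location=mapping)
```

If the number of source GPUs is unknown, a single device string avoids having to enumerate tags at all.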

asked Nov 07 '18 09:11 by Marzi Heidari


1 Answer

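I figured it out. The solution is to pass a single device string as map_location:

loaded_state = torch.load(model_path+seq_to_seq_test_model_fname, map_location='cuda:0')

With a string, every saved storage is loaded onto that device, whichever GPU it was originally saved from.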
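A self-contained sketch of the same idea, assuming a reasonably recent PyTorch: it saves a small state dict to an in-memory buffer (standing in for the real checkpoint file), then reloads it with a single device string as map_location. It falls back to "cpu" when no GPU is present so that it also runs on CPU-only machines. The helper name load_onto is hypothetical.

```python
import io

import torch


def load_onto(buffer, device):
    """Load a serialized object, moving every saved storage to `device`.

    `device` can be a string such as "cpu" or "cuda:0", or a torch.device.
    """
    buffer.seek(0)
    return torch.load(buffer, map_location=device)


# Stand-in for a checkpoint: a state dict of CPU tensors saved to memory.
# A checkpoint from a multi-GPU run would carry "cuda:N" tags instead.
state = {"weight": torch.ones(2, 3), "bias": torch.zeros(3)}
buf = io.BytesIO()
torch.save(state, buf)

# Use the one device that actually exists on this machine.
target = "cuda:0" if torch.cuda.is_available() else "cpu"
loaded = load_onto(buf, target)
print({name: t.device for name, t in loaded.items()})
```

A torch.device object, for example torch.device("cuda:0"), is accepted in place of the string.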

answered Oct 21 '22 21:10 by Marzi Heidari