 

Pytorch: AttributeError: 'function' object has no attribute 'copy'

I am trying to load a model state_dict that I trained on a Google Colab GPU. Here is my code to load the model:

import copy
import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)  # n_classes: number of output classes (not shown)
model.load_state_dict(copy.deepcopy(torch.load("./models/model.pth", device)))
model = model.to(device)
model.eval()

Here is the error:

state_dict = state_dict.copy()

AttributeError: 'function' object has no attribute 'copy'

PyTorch versions:

>>> import torch
>>> print (torch.__version__)
1.4.0
>>> import torchvision
>>> print (torchvision.__version__)
0.5.0

Please help; I have searched everywhere to no avail.

Full error details: https://i.stack.imgur.com/s22DL.png

Asked Apr 16 '20 04:04 by muuka7


2 Answers

I am guessing this is what you did by mistake: you saved the state_dict method itself

torch.save(model.state_dict, 'model_state.pth')

instead of calling state_dict() to get the actual state dict:

torch.save(model.state_dict(), 'model_state.pth')

Otherwise, everything should work as expected. I tested the following code on Colab.

Replace model.state_dict() with model.state_dict in the torch.save call to reproduce the error:

import copy
import torch

model = TheModelClass()  # TheModelClass stands for any nn.Module subclass
torch.save(model.state_dict(), 'model_state.pth')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(copy.deepcopy(torch.load("model_state.pth", device)))
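To see the difference directly, here is a small sketch that saves both versions and compares what torch.load returns. It is an illustration under assumptions: nn.Linear(4, 2) stands in for the real model, the file names are hypothetical, and it assumes PyTorch 1.13 or later because it passes the weights_only argument to torch.load (on 1.4, drop that argument).

```python
import copy
import os
import tempfile

import torch
import torch.nn as nn

# nn.Linear(4, 2) is a stand-in for the real model; any nn.Module behaves the same.
model = nn.Linear(4, 2)

tmpdir = tempfile.mkdtemp()
good_path = os.path.join(tmpdir, "good.pth")  # hypothetical file names
bad_path = os.path.join(tmpdir, "bad.pth")

# Correct: state_dict() is called, so a dict of tensors is saved.
torch.save(model.state_dict(), good_path)
# Mistake: the bound method itself is pickled, together with the model it belongs to.
torch.save(model.state_dict, bad_path)

good = torch.load(good_path, map_location="cpu", weights_only=True)
# A pickled method can only be restored with full unpickling; do this only for trusted files.
bad = torch.load(bad_path, map_location="cpu", weights_only=False)

print(type(good))  # a dict subclass (OrderedDict)
print(type(bad))   # a bound method, not a dict

model.load_state_dict(copy.deepcopy(good))  # succeeds

caught = None
try:
    model.load_state_dict(bad)
except (AttributeError, TypeError) as err:
    # PyTorch 1.4 fails inside load_state_dict with an AttributeError like the one
    # in the question; newer releases reject non-dict input with a TypeError instead.
    caught = err
print("load_state_dict failed:", caught)
```

The only difference between the two files is the pair of parentheses after state_dict; everything downstream of torch.load follows from that.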
Answered Nov 14 '22 15:11 by cozek


This happens because you saved your model with

torch.save(model.state_dict, 'model_state.pth')
instead of

torch.save(model.state_dict(), 'model_state.pth')

As a result, you saved a reference to your model's state_dict method instead of the state dict itself. To work around this, call the object returned by torch.load to get the actual state dict:

model.load_state_dict(copy.deepcopy(torch.load("./models/model.pth",device)()))

instead of

model.load_state_dict(copy.deepcopy(torch.load("./models/model.pth",device)))

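The object returned by torch.load("./models/model.pth", device) is that method bound to a copy of your model, so printing it also shows your model's layer details and gives you access to the model's other attributes.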
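A minimal sketch of a loader that accepts either kind of checkpoint, plus a helper that re-saves a mistaken checkpoint as a plain state dict so the original loading code works again. load_weights and fix_checkpoint are hypothetical helper names, nn.Linear(4, 2) stands in for the ResNet-50, and weights_only assumes PyTorch 1.13 or later (on 1.4, drop that argument). Because the mistaken file contains a pickled copy of the whole model, the model's class must be importable in the script that loads it.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_weights(model, path, device="cpu"):
    """Load weights saved either as state_dict() or, by mistake, as the method state_dict."""
    # weights_only=False is required because a mistaken checkpoint holds a pickled
    # method and model; only use it on files you trust.
    obj = torch.load(path, map_location=device, weights_only=False)
    if callable(obj):
        # Calling the bound method on the unpickled model copy yields the real state dict.
        obj = obj()
    model.load_state_dict(obj)
    return obj


def fix_checkpoint(src, dst, device="cpu"):
    """Re-save a mistaken checkpoint as a plain state dict."""
    obj = torch.load(src, map_location=device, weights_only=False)
    state_dict = obj() if callable(obj) else obj
    torch.save(state_dict, dst)


tmpdir = tempfile.mkdtemp()
bad_path = os.path.join(tmpdir, "bad.pth")      # hypothetical file names
fixed_path = os.path.join(tmpdir, "fixed.pth")

trained = nn.Linear(4, 2)                 # stand-in for the trained model
torch.save(trained.state_dict, bad_path)  # the mistaken save

fresh = nn.Linear(4, 2)
load_weights(fresh, bad_path)             # works despite the mistake
fix_checkpoint(bad_path, fixed_path)      # fixed.pth now holds a plain state dict
```

Running fix_checkpoint once means later loads can go back to the plain model.load_state_dict(torch.load(...)) call.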

Answered Nov 14 '22 13:11 by MRA_7718