I'm trying to load a model by following this tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference . Unfortunately, I'm a complete beginner and have run into some problems.
I have created a checkpoint:
checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),'loss': loss}
torch.save(checkpoint, 'checkpoint.pth')
Then I wrote a class for my network and tried to load the file:
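import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, 102)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        # F.log_softmax replaces the undefined log(F.softmax(...)) and is numerically more stable
        x = F.log_softmax(x, dim=1)
        return x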
Like this:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network()
    model.load_state_dict(checkpoint['model_state_dict'])
    # assumes `optimizer` already exists in the enclosing scope
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    # return the model so the assignment below does not receive None
    return model

model = load_checkpoint('checkpoint.pth')
I got this error (edited to show the whole message):
RuntimeError: Error(s) in loading state_dict for Network:
Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias".
Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "features.6.weight", "features.6.bias", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias", "classifier.fc3.weight", "classifier.fc3.bias".
This is my model.state_dict().keys():
odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight',
'features.3.bias', 'features.6.weight', 'features.6.bias',
'features.8.weight', 'features.8.bias', 'features.10.weight',
'features.10.bias', 'classifier.fc1.weight', 'classifier.fc1.bias',
'classifier.fc2.weight', 'classifier.fc2.bias', 'classifier.fc3.weight',
'classifier.fc3.bias'])
This is my model:
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (fc1): Linear(in_features=9216, out_features=4096, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=4096, out_features=1000, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=1000, out_features=102, bias=True)
    (output): LogSoftmax()
  )
)
It's my first network ever and I'm blundering along. Thanks for steering me in the right direction!
A state_dict is central to saving and loading models in PyTorch. Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
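To make the "plain dictionary" point concrete, here is a minimal sketch. The toy nn.Sequential model and the in-memory buffer are assumptions chosen for illustration; they are not the asker's network or checkpoint file.

```python
import io

import torch
import torch.nn as nn

# Hypothetical toy model, used only to show what a state_dict looks like.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

state = model.state_dict()   # an ordered dict: parameter name -> tensor
keys = list(state.keys())    # ReLU has no parameters, so index 1 is absent

# Save to an in-memory buffer instead of a file such as 'checkpoint.pth'.
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# Restore into a freshly constructed model with the same architecture.
restored = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
restored.load_state_dict(torch.load(buffer))

# Every tensor should survive the round trip unchanged.
same = all(torch.equal(state[k], restored.state_dict()[k]) for k in keys)
print(keys, same)
```

Loading succeeds here only because both models produce exactly the same key names, which is the condition the question's Network fails to meet.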
So your Network is essentially the classifier part of AlexNet, and you're looking to load pretrained AlexNet weights into it. The problem is that the keys in state_dict are "fully qualified": if you look at your network as a tree of nested modules, a key is just the list of modules along one branch, joined with dots, like grandparent.parent.child. You want to keep only the tensors whose names start with "classifier." and remove that prefix from their keys, so try:
checkpoint = torch.load('checkpoint.pth')
model = Network()
loaded_dict = checkpoint['model_state_dict']
prefix = 'classifier.'
n_clip = len(prefix)
# keep only the classifier entries and strip the prefix from their keys
adapted_dict = {k[n_clip:]: v for k, v in loaded_dict.items()
                if k.startswith(prefix)}
model.load_state_dict(adapted_dict)
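The same idea can be checked end to end on a small model. This is only a sketch: the classes Full and Classifier, the helper strip_prefix, and the tiny layer sizes are all hypothetical stand-ins for AlexNet and Network, chosen so the example runs quickly.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Hypothetical stand-in for the asker's Network."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


class Full(nn.Module):
    """Hypothetical stand-in for the full AlexNet-style model."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 8))
        self.classifier = Classifier()

    def forward(self, x):
        return self.classifier(self.features(x))


def strip_prefix(state_dict, prefix):
    """Keep entries whose key starts with `prefix`, with the prefix removed."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


full = Full()
full_keys = list(full.state_dict().keys())   # fully qualified, dotted names
adapted = strip_prefix(full.state_dict(), 'classifier.')

small = Classifier()
small.load_state_dict(adapted)               # strict by default: raises on any key mismatch
print(full_keys)
print(list(adapted.keys()))
```

Because load_state_dict is strict by default, a successful call confirms that the adapted keys match the smaller model exactly, with nothing missing and nothing unexpected.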