
PyTorch - RuntimeError: Error(s) in loading state_dict for VGG:

I trained a model using PyTorch and saved a state dict file. When I try to load the pre-trained model using the code below, I get the following error:

RuntimeError: Error(s) in loading state_dict for VGG:
    Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 
    Unexpected key(s) in state_dict: "state_dict", "optimizer_state_dict", "globalStep", "train_paths", "test_paths". 

I am following the instructions available at this site: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices

Many thanks.

import argparse
import datetime
import glob
import os
import random
import shutil
import time
from os.path import join

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch.optim as optim

from convnet3 import Convnet
from VGG import VGG
from dataset2 import CellsDataset
from torchvision import models
from Conv import Conv2d

parser = argparse.ArgumentParser('Predicting hits from pixels')
parser.add_argument('name',type=str,help='Name of experiment')
parser.add_argument('data_dir',type=str,help='Path to data directory containing images and gt.csv')
parser.add_argument('--weight_decay',type=float,default=0.0,help='Weight decay coefficient (something like 10^-5)')
parser.add_argument('--lr',type=float,default=0.0001,help='Learning rate')
args = parser.parse_args()

metadata = pd.read_csv(join(args.data_dir,'gt.csv'))
metadata.set_index('filename', inplace=True)

# create datasets:

dataset = CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
dataset = DataLoader(dataset,num_workers=4,pin_memory=True)
model_path = '/Users/nubstech/Documents/GitHub/CellCountingDirectCount/VGG_model_V1/checkpoints/checkpoint.pth'

class VGG(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG, self).__init__()
        vgg = models.vgg16(pretrained=pretrained)
        # if pretrained:
        vgg.load_state_dict(torch.load(model_path))
        features = list(vgg.features.children())
        self.features4 = nn.Sequential(*features[0:23])


        self.de_pred = nn.Sequential(Conv2d(512, 128, 1, same_padding=True, NL='relu'),
                                     Conv2d(128, 1, 1, same_padding=True, NL='relu'))


    def forward(self, x):
        x = self.features4(x)       
        x = self.de_pred(x)

        return x

model=VGG()
#model.load_state_dict(torch.load(model_path),strict=False)
model.eval()        

#optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

for images, paths in tqdm(dataset):

    targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
    targets = targets.float()

    # code to print training data to a csv file
    #filename=CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
    output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
    preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)
    print(preds)
    paths_test = np.array([paths])
    names_preds = np.hstack(paths)
    print(names_preds)                
    df=pd.DataFrame({'Image_Name':names_preds, 'Target':targets.detach(), 'Prediction':preds.detach()})
    print(df) 
    # save image name, targets, and predictions
    df.to_csv(r'model.csv', index=False, mode='a')

Code for saving the state dict:

        torch.save({'state_dict':model.state_dict(),
                    'optimizer_state_dict':optimizer.state_dict(),
                    'globalStep':global_step,
                    'train_paths':dataset_train.files,
                    'test_paths':dataset_test.files},checkpoint_path)

Browed1983 asked Oct 28 '25 02:10


2 Answers

The problem is that what was saved is not what the loading code expects. The code passes the whole checkpoint to load_state_dict, but the checkpoint is a dict that wraps the model's state_dict (under the key "state_dict") together with additional info such as the optimizer state and file paths. load_state_dict has no logic to look inside that outer dict, which is why it reports the outer keys as unexpected and every model key as missing.

This should work:

import torch, torchvision.models
model = torchvision.models.vgg16()
path = 'test.pth'
torch.save(model.state_dict(), path) # nothing else here
model.load_state_dict(torch.load(path))
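
If the checkpoint already exists in the nested form shown in the question, an alternative to re-saving is to pull the weights out of the outer dict before loading them. The following is only a sketch: extract_model_weights is a hypothetical helper, and the dict stands in for what torch.load(model_path) would return, with made-up parameter names and placeholder values instead of tensors.

```python
def extract_model_weights(checkpoint, key="state_dict"):
    """Return the model weights from a checkpoint that may be nested.

    If `checkpoint` is a dict holding the weights under `key` (as in the
    question's torch.save call), return that inner dict. Otherwise assume
    the checkpoint already is a plain state_dict and return it unchanged.
    """
    if isinstance(checkpoint, dict) and key in checkpoint:
        return checkpoint[key]
    return checkpoint


# Stand-in for torch.load(model_path). The top-level keys match the
# question; the parameter names and values inside are illustrative only.
checkpoint = {
    "state_dict": {"features4.0.weight": [0.1], "features4.0.bias": [0.0]},
    "optimizer_state_dict": {},
    "globalStep": 100,
    "train_paths": [],
    "test_paths": [],
}

weights = extract_model_weights(checkpoint)
print(sorted(weights))

# With PyTorch, the same idea would look like:
#   checkpoint = torch.load(model_path, map_location="cpu")
#   model.load_state_dict(extract_model_weights(checkpoint))
```

The model instance passed the weights must have the same architecture as the one whose state_dict was saved. If the checkpoint came from a custom network rather than a plain torchvision vgg16, loading it into vgg16 would presumably still report mismatched keys.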

Alex I answered Oct 30 '25 09:10


Another reason may be that the saved model was wrapped in nn.DataParallel() during training, while the model you are building now is not. In that case every key in the saved state_dict carries an extra "module." prefix. The reverse can also happen: you build the model with nn.DataParallel() active, but the saved state_dict was not trained in parallel and so lacks the prefix. Here is a link I used to fix the error: https://github.com/bearpaw/pytorch-classification/issues/27
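
One common workaround for the first case is to rename the keys before loading. The following is only a sketch: strip_module_prefix is a hypothetical helper, and the dict stands in for a state_dict loaded from disk, with placeholder strings instead of tensors.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Keys that do not start with the prefix are copied unchanged, so the
    function is safe to call on a state_dict that was never wrapped.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for a state_dict saved from an nn.DataParallel-wrapped model.
saved = OrderedDict([
    ("module.features.0.weight", "w0"),
    ("module.features.0.bias", "b0"),
])

cleaned = strip_module_prefix(saved)
print(list(cleaned))

# With PyTorch, the same idea would look like:
#   model.load_state_dict(strip_module_prefix(torch.load(path)))
```

For the reverse case, wrapping the freshly built model in nn.DataParallel() before calling load_state_dict should make the key names line up without any renaming.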

Tim Bretschneider answered Oct 30 '25 09:10


