How to load a checkpoint file in a PyTorch model?

In my PyTorch code, I initialize my model and optimizer like this.

model = MyModelClass(config, shape, x_tr_mean, x_tr_std)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)

And here is the path to my checkpoint file.

checkpoint_file = os.path.join(config.save_dir, "checkpoint.pth")

To resume, I check whether the checkpoint file exists and, if it does, load it into the model and optimizer.

if os.path.exists(checkpoint_file):
    if config.resume:
        torch.load(checkpoint_file)
        model.load_state_dict(torch.load(checkpoint_file))
        optimizer.load_state_dict(torch.load(checkpoint_file))

Also, here's how I'm saving my model and optimizer.

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter_idx': iter_idx,
            'best_va_acc': best_va_acc}, checkpoint_file)

For some reason I keep getting a strange error whenever I run this code.

model.load_state_dict(torch.load(checkpoint_file))
File "/home/Josh/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MyModelClass:
        Missing key(s) in state_dict: "mean", "std", "attribute.weight", "attribute.bias".
        Unexpected key(s) in state_dict: "model", "optimizer", "iter_idx", "best_va_acc"

Does anyone know why I'm getting this error?

asked Feb 13 '19 at 19:02 by Josh Susa


1 Answer

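You saved the model parameters inside a dictionary, alongside the optimizer state and your counters, so torch.load returns that whole dictionary rather than a bare state_dict. That is why load_state_dict reports "model", "optimizer", "iter_idx" and "best_va_acc" as unexpected keys. Use the same keys you used when saving to pull out each state_dict before loading it: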

if os.path.exists(checkpoint_file):
    if config.resume:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
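
To confirm the diagnosis, it can help to print the top-level keys of the loaded object. The following is a minimal sketch, not the asker's actual setup: `nn.Linear` stands in for `MyModelClass`, and the checkpoint is written to a temporary directory. Both are assumptions made only so the snippet runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for the asker's model and optimizer.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Save in the same layout as the question: a wrapper dict with four entries.
checkpoint_file = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter_idx': 0,
            'best_va_acc': 0.0}, checkpoint_file)

checkpoint = torch.load(checkpoint_file)

# The file holds the wrapper dict, not a bare state_dict.
top_level_keys = sorted(checkpoint.keys())
print(top_level_keys)

# The parameter names the model expects live one level down.
model_keys = sorted(checkpoint['model'].keys())
print(model_keys)

# Passing the wrapper dict straight to load_state_dict reproduces the error.
try:
    model.load_state_dict(checkpoint)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```

The first print shows the keys used at save time, and the caught RuntimeError lists those same names as unexpected keys, matching the traceback in the question.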

You can check the official tutorial on the PyTorch website for more info.
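
The checkpoint in the question also stores `iter_idx` and `best_va_acc`, which a resumed run would probably want back as well. Here is one possible way to wrap the save and load steps so that nothing is left behind. The helper names `save_checkpoint` and `load_checkpoint` are hypothetical, `nn.Linear` is again a stand-in for the real model, and the values 120 and 0.87 are made up for the demonstration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, iter_idx, best_va_acc):
    """Bundle everything needed to resume into one dict and write it to disk."""
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'iter_idx': iter_idx,
                'best_va_acc': best_va_acc}, path)


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place and return the saved counters."""
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['iter_idx'], checkpoint['best_va_acc']


# Hypothetical training state that gets saved.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
save_checkpoint(path, model, optimizer, iter_idx=120, best_va_acc=0.87)

# A fresh model and optimizer, as they would exist at the start of a new run.
resumed_model = nn.Linear(4, 2)
resumed_optimizer = optim.SGD(resumed_model.parameters(), lr=0.1, momentum=0.9)

# Defaults apply when there is no checkpoint to resume from.
iter_idx, best_va_acc = 0, 0.0
if os.path.exists(path):
    iter_idx, best_va_acc = load_checkpoint(path, resumed_model, resumed_optimizer)

weights_match = torch.equal(model.weight, resumed_model.weight)
```

Note that the model and optimizer must be constructed before loading, because `load_state_dict` copies values into existing objects rather than creating them.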

answered Oct 30 '22 at 16:10 by kHarshit