How to save model architecture in PyTorch?

Tags:

pytorch

I know I can save a model with torch.save(model.state_dict(), FILE) or torch.save(model, FILE), but neither of them saves the architecture of the model.

So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow? I want to apply different tweaks to my model. If I can't save the architecture of a model, is there a better way than copying the whole class definition and creating a new class every time?

Asked Jan 05 '20 00:01 by M.Z.


3 Answers

You can refer to this article to understand how to save the classifier. To make tweaks to a model, you can create a new model class that inherits from the existing one.


# oldModelClass is your existing nn.Module subclass
class newModel(oldModelClass):
    def __init__(self):
        super().__init__()  # builds all of oldModelClass's layers

With this setup, newModel inherits all the layers as well as the forward function of oldModelClass. To make tweaks, define the new layers in __init__ (after the super().__init__() call) and override forward to use them.
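
A minimal sketch of that pattern, assuming a small feed-forward model (OldModel and NewModel are hypothetical stand-ins for oldModelClass and newModel). It also shows one way to reuse the old model's trained weights: load_state_dict with strict=False fills the inherited layers and reports the new layer as missing.

```python
import torch
import torch.nn as nn

class OldModel(nn.Module):
    # Hypothetical stand-in for oldModelClass
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

class NewModel(OldModel):
    # Tweaked variant: inherits fc1/fc2 and adds a new layer on top
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        # Reuse the parent's forward pass, then apply the new layer
        return self.head(super().forward(x))

old = OldModel()
new = NewModel()
# Copy the old weights into the inherited layers; the new layer keeps its init
result = new.load_state_dict(old.state_dict(), strict=False)
print(result.missing_keys)           # ['head.weight', 'head.bias']
print(new(torch.randn(3, 8)).shape)  # torch.Size([3, 2])
```

Subclassing keeps the original class as the single source of truth, so each tweak only has to spell out what differs.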

Answered Oct 19 '22 10:10 by Roshan Santhosh


Saving all the parameters (state_dict) and all the Modules is not enough, because some operations that manipulate tensors exist only in the code of the specific implementation (e.g., the reshaping in ResNet's forward).
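
To make this concrete, here is a small sketch with two hypothetical models that share exactly the same parameters but compute different things. The state_dict only holds tensors keyed by name, so it loads cleanly into either class and cannot tell them apart.

```python
import torch
import torch.nn as nn

class FlattenNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(12, 2)

    def forward(self, x):
        # The flatten exists only in this code; state_dict has no entry for it
        return self.fc(torch.flatten(x, 1))

class ScaledNet(FlattenNet):
    def forward(self, x):
        # Same parameters as FlattenNet, different computation
        return self.fc(torch.flatten(x, 1) * 2.0)

torch.manual_seed(0)
a = FlattenNet()
b = ScaledNet()
b.load_state_dict(a.state_dict())  # loads cleanly: the keys are identical

print(list(a.state_dict().keys()))  # ['fc.weight', 'fc.bias']
x = torch.randn(5, 3, 4)
same = torch.allclose(a(x), b(x))  # outputs differ despite identical weights
```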

Furthermore, the network might not have a fixed, predetermined compute graph: consider a network that has branching or a loop (recurrence).
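
A sketch of why branching is a problem for recording a graph, using torch.jit.trace as one example of a graph-recording tool (BranchyNet is hypothetical). Tracing runs the model once on an example input and records only the operations that actually executed, so the branch taken for that input gets baked in.

```python
import warnings
import torch
import torch.nn as nn

class BranchyNet(nn.Module):
    def forward(self, x):
        # Data-dependent branch: which path runs depends on the input values
        if x.sum() > 0:
            return x * 2
        return x - 1

model = BranchyNet()
with warnings.catch_warnings():
    # trace emits a TracerWarning about converting a tensor to a Python bool
    warnings.simplefilter("ignore")
    # Positive example input: only the "x * 2" path is recorded
    traced = torch.jit.trace(model, torch.ones(3))

neg = torch.full((3,), -3.0)
print(model(neg))   # eager code takes the other branch: tensor([-4., -4., -4.])
print(traced(neg))  # recorded graph still multiplies by 2: tensor([-6., -6., -6.])
```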

Therefore, you must save the actual code.

Alternatively, if there are no branches or loops in the network, you can save the computation graph; see, e.g., this post.
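
One way to do this (not necessarily the approach in the linked post) is TorchScript tracing: torch.jit.trace records the graph for a straight-line model, and torch.jit.save / torch.jit.load round-trip that graph together with the weights. A minimal sketch, assuming a hypothetical SmallNet and an in-memory buffer standing in for a file:

```python
import io
import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(12, 2)

    def forward(self, x):
        # No branches or loops, so tracing captures the full computation,
        # including this flatten
        return self.fc(torch.flatten(x, 1))

torch.manual_seed(0)
model = SmallNet().eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 4))

buffer = io.BytesIO()  # stands in for a file path such as "model.pt"
torch.jit.save(traced, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)  # running this does not need the SmallNet class

x = torch.randn(5, 3, 4)
print(torch.allclose(model(x), restored(x)))  # True
```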

You should also consider exporting your model to ONNX, which gives you a representation that captures both the trained weights and the computation graph.

Answered Oct 19 '22 08:10 by Shai


Regarding the actual question:

So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow?

The answer is: you cannot.

Is there any way to load a trained model without declaring the class definition first? I want the model architecture as well as the parameters to be loaded.

No, you have to load the class definition first; this is a limitation of Python pickling.

https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11

That said, there are other options (you have probably seen most of them already) listed in this PyTorch tutorial:

https://pytorch.org/tutorials/beginner/saving_loading_models.html
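
For reference, the approach that tutorial recommends is to save only the state_dict and rebuild the model from its class when loading, which is exactly why the class definition has to be available. A minimal sketch (Net is hypothetical; an in-memory buffer stands in for a file):

```python
import io
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = Net()
buffer = io.BytesIO()  # stands in for a file path such as "net.pt"
torch.save(model.state_dict(), buffer)

# Loading side: the Net class must be importable to rebuild the module
buffer.seek(0)
restored = Net()
restored.load_state_dict(torch.load(buffer))
restored.eval()  # switch to inference mode before evaluating
```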

Answered Oct 19 '22 08:10 by Xxxo