I know I can save a model with torch.save(model.state_dict(), FILE) or torch.save(model, FILE). But neither of them saves the architecture of the model.
So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow? I want to apply different tweaks to my model. If I can't save the architecture of a model, is there a better way than copying the whole class definition and creating a new class every time?
You can refer to this article to understand how to save the classifier. To make tweaks to a model, you can create a new model class that is a child (subclass) of the existing model class.
    class newModel(oldModelClass):
        def __init__(self):
            super(newModel, self).__init__()
With this setup, newModel has all the layers as well as the forward function of oldModelClass. If you need to make tweaks, you can define new layers in the __init__ function and then write a new forward function that uses them.
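The subclassing approach above can be sketched as follows. OldModel and NewModel are hypothetical stand-ins for oldModelClass and newModel, and the dropout layer is just an example of a tweak. Because the inherited layers keep their attribute names, a state_dict saved from the old model should load into the new one as long as the new layers add no parameters.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the existing model ("oldModelClass" in the answer).
class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Tweaked variant: inherits fc1/fc2 and adds a dropout layer between them.
class NewModel(OldModel):
    def __init__(self):
        super().__init__()                 # builds fc1 and fc2 exactly as OldModel does
        self.dropout = nn.Dropout(p=0.5)   # new layer; it has no parameters

    def forward(self, x):
        # Overridden forward that routes the hidden activations through the new layer.
        h = torch.relu(self.fc1(x))
        return self.fc2(self.dropout(h))

# The inherited layers keep their names, so OldModel weights load directly.
old = OldModel()
new = NewModel()
new.load_state_dict(old.state_dict())
out = new.eval()(torch.randn(3, 4))       # eval() disables dropout
print(out.shape)  # torch.Size([3, 2])
```

If the tweak does add parameters, load_state_dict(..., strict=False) lets the shared layers load while the new ones keep their fresh initialization.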
Saving all the parameters (state_dict) and all the Modules is not enough, since some operations that manipulate the tensors are reflected only in the actual code of the specific implementation (e.g., reshaping in ResNet).
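A minimal sketch of this point, using two hypothetical models: they have identical state_dict keys and values and differ only in a reshape inside forward(), yet they produce different outputs. Nothing in the saved parameters records which reshape was used.

```python
import torch
import torch.nn as nn

# Two hypothetical models with identical parameters; they differ only in a
# reshape performed inside forward().
class FlattenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x.flatten(1))                  # row-major order of the 2x2 input

class TransposeFlattenModel(FlattenModel):
    def forward(self, x):
        return self.fc(x.transpose(1, 2).flatten(1))  # column-major order

a = FlattenModel()
with torch.no_grad():                                 # fixed weights for a reproducible result
    a.fc.weight.copy_(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
    a.fc.bias.zero_()

b = TransposeFlattenModel()
b.load_state_dict(a.state_dict())                     # same keys, same values

x = torch.arange(4.0).reshape(1, 2, 2)                # [[[0, 1], [2, 3]]]
print(list(a.state_dict()) == list(b.state_dict()))   # True
print(a(x).item(), b(x).item())                       # 20.0 19.0
```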
Furthermore, the network might not have a fixed, predetermined compute graph: think of a network that has branching or a loop (recurrence).
Therefore, you must save the actual code.
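To make the branching case concrete, here is a minimal sketch with a hypothetical Gate module whose control flow depends on the input values. It uses TorchScript as one illustration: torch.jit.trace runs forward once and records only the branch taken for the example input, while torch.jit.script compiles the Python source, so both branches survive.

```python
import warnings
import torch
import torch.nn as nn

# Hypothetical model whose control flow depends on the input values.
class Gate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x + 10

model = Gate()
pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing runs forward once (on `pos`) and records only the branch it took.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")       # hide the TracerWarning about the branch
    traced = torch.jit.trace(model, pos)

# Scripting compiles the source code instead, so the if/else is preserved.
scripted = torch.jit.script(model)

print(model(neg))     # tensor([9., 9., 9.])     eager code takes the else-branch
print(traced(neg))    # tensor([-2., -2., -2.])  trace replays x * 2
print(scripted(neg))  # tensor([9., 9., 9.])
```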
Alternatively, if there are no branches or loops in the net, you may save the computation graph; see, e.g., this post.
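One way to save the computation graph of a straight-line network in PyTorch is TorchScript tracing; the post referred to above may describe a different method. A minimal sketch, assuming a hypothetical Net class and an in-memory buffer in place of a file: the traced archive stores both the recorded graph and the weights, so torch.jit.load does not need the Net class definition.

```python
import io
import torch
import torch.nn as nn

# Hypothetical straight-line model: no data-dependent branches or loops.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = Net().eval()
example = torch.randn(1, 4)

# Tracing runs the model once and records the tensor operations it executes.
traced = torch.jit.trace(model, example)

# Serialize graph and weights together (a file path works the same way as this buffer).
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Loading reads the stored graph; the Python Net class is not consulted.
loaded = torch.jit.load(buffer)
x = torch.randn(5, 4)
print(torch.allclose(loaded(x), model(x)))  # True
```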
You should also consider exporting your model to ONNX, which gives you a representation that captures both the trained weights and the computation graph.
Regarding the actual question:
So how can we save the architecture of a model in PyTorch, like creating a .pb file in TensorFlow?
The answer is: you cannot.
Is there any way to load a trained model without declaring the class definition first? I want both the model architecture and the parameters to be loaded.
No, you have to load the class definition first; this is a Python pickling limitation.
https://discuss.pytorch.org/t/how-to-save-load-torch-models/718/11
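Since torch.save(model, FILE) serializes the model with Python's pickle module, the limitation can be seen with plain pickle: classes are pickled by reference (module name plus class name), not by value. A minimal sketch, assuming a hypothetical module my_models containing a hypothetical TinyNet class, built in memory so the example is self-contained.

```python
import pickle
import sys
import types

# Hypothetical module "my_models" holding the class definition (normally a
# my_models.py file); built on the fly so this sketch is self-contained.
my_models = types.ModuleType("my_models")
exec(
    "class TinyNet:\n"
    "    def __init__(self):\n"
    "        self.weights = [0.5, -1.0]\n",
    my_models.__dict__,
)
sys.modules["my_models"] = my_models

# Pickle stores the instance's data plus only a reference to its class,
# i.e. the names "my_models" and "TinyNet", not the class body.
blob = pickle.dumps(my_models.TinyNet())
print(b"TinyNet" in blob)        # True

restored = pickle.loads(blob)    # works: the class can still be looked up
print(restored.weights)          # [0.5, -1.0]

# Simulate loading where the class definition is not available.
del my_models.TinyNet
load_failed = False
try:
    pickle.loads(blob)
except (AttributeError, pickle.UnpicklingError) as exc:
    load_failed = True
    print("load failed:", exc)
```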
There are, however, other options (you have probably seen most of them already) listed in this PyTorch tutorial:
https://pytorch.org/tutorials/beginner/saving_loading_models.html