
Saving PyTorch model with no access to model class code

How can I save a PyTorch model so that it can be loaded without the model class being defined anywhere?


Disclaimer:

The question "Best way to save a trained model in PyTorch?" has no working solution for saving a model without access to the model class code.

Michael D asked Dec 11 '19 14:12


2 Answers

If you plan to do inference with the PyTorch library available (i.e. PyTorch in Python, C++, or another platform it supports), then the best way to do this is via TorchScript.

I think the simplest thing is to use trace = torch.jit.trace(model, typical_input) and then torch.jit.save(trace, path). You can then load the traced model with torch.jit.load(path).

Here's a really simple example. We make two files:

train.py :

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = Model()
model.eval()  # trace the model in inference mode
x = torch.tensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    # Record the operations executed for this example input.
    traced_model = torch.jit.trace(model, x)
# The saved file holds both the graph and the weights,
# so loading it does not require the Model class.
torch.jit.save(traced_model, "model.pth")

infer.py :

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))

Running these sequentially gives results:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

The outputs match, so the loaded trace reproduces the original model. (The values will differ each time you run train.py, because the nn.Linear layer is randomly initialised.)

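TorchScript also supports much more complex architectures and graph definitions (including if statements, while loops, and more), all saved in a single file, with no need to redefine the graph at inference time. Note that torch.jit.trace records only the operations executed for the example input, so data-dependent control flow should be compiled with torch.jit.script instead. See the TorchScript docs for more advanced possibilities.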
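As a minimal sketch of the control-flow case, the example below compiles a module with torch.jit.script rather than tracing it. The Gate class, its branch condition, and the file name gate_scripted.pt are hypothetical and chosen only for illustration; they are not part of the original answer.

```python
import os
import tempfile

import torch


class Gate(torch.nn.Module):
    """Hypothetical module whose forward pass branches on the input values."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: tracing would record only the side that
        # the example input happened to take.
        if x.sum() > 0:
            return torch.relu(self.linear(x))
        return torch.zeros_like(x)


model = Gate().eval()
# Scripting compiles the Python source of forward(), including the if.
scripted = torch.jit.script(model)

path = os.path.join(tempfile.gettempdir(), "gate_scripted.pt")
torch.jit.save(scripted, path)

# At inference time only torch and the saved file are needed.
loaded = torch.jit.load(path)
pos = torch.ones(2, 4)    # takes the linear + relu branch
neg = -torch.ones(2, 4)   # takes the zeros branch
with torch.no_grad():
    same_pos = torch.allclose(loaded(pos), model(pos))
    same_neg = torch.equal(loaded(neg), torch.zeros(2, 4))
print(same_pos, same_neg)
```

Both branches survive the save and load round trip here, which is the property a trace of the same module would not guarantee.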

nlml answered Nov 02 '22 02:11


I recommend converting your PyTorch model to ONNX and saving that. It is probably the best way to store a model without access to the class.

Dima Komarovski answered Nov 02 '22 03:11