
How can I load a model in PyTorch without redefining the model?

I am looking for a way to save a PyTorch model and load it without the model definition. In other words, I want to save the model together with its definition.

For example, I would like to have two scripts. The first would define, train, and save the model. The second would load the model and use it for prediction without including the model definition.

The approach using torch.save() and torch.load() requires me to include the model definition in the prediction script, but I want to load the model without redefining it there.

김수호 asked Jan 16 '20 16:01


1 Answer

You can try exporting your model to TorchScript using tracing, though this has limitations. Tracing runs the model once on an example input and records the operations that were actually executed. Because PyTorch builds the computation graph on the fly, any Python control flow in your model (for example, an if statement that depends on a module attribute) is frozen to the path taken during that run, so the exported model may not completely represent your Python module. TorchScript is only supported in PyTorch >= 1.0.0, though I would recommend using the latest version possible.

For example, a model without any conditional behavior is fine

import torch
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc = nn.Linear(20 * 4 * 4, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        x = self.fc(x.flatten(1))
        return x
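The 20 * 4 * 4 in nn.Linear assumes 16x16 inputs: the padded 3x3 convolutions keep the spatial size unchanged, and each max_pool2d(x, 2, 2) halves it (16 to 8 to 4). A minimal sketch that checks this arithmetic by running only the convolution and pooling layers on a dummy input; the helper name feature_size is hypothetical, and BatchNorm2d is left out because it does not change shapes.

```python
import torch
from torch import nn
import torch.nn.functional as F


def feature_size(height: int, width: int) -> int:
    """Hypothetical helper: flattened feature size per sample for the example Model."""
    conv1 = nn.Conv2d(3, 10, 3, padding=1)    # padding=1 keeps H and W
    conv2 = nn.Conv2d(10, 20, 3, padding=1)
    x = torch.zeros(1, 3, height, width)
    x = F.max_pool2d(F.relu(conv1(x)), 2, 2)  # halves H and W
    x = F.max_pool2d(F.relu(conv2(x)), 2, 2)  # halves H and W again
    return x.flatten(1).shape[1]


print(feature_size(16, 16))  # 320, i.e. 20 * 4 * 4
```

Feeding a different input size (for example 32x32 gives 20 * 8 * 8) would require changing the nn.Linear input size accordingly.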

We can export this as follows

import torch
from torch import jit

net = Model()
# ... train your model

# put the model in the mode you want to export (see the note on train/eval modes below)
net.eval()

# print example output
x = torch.ones(1, 3, 16, 16)
print(net(x))

# create TorchScript by tracing the computation graph with the same example input
net_trace = jit.trace(net, x)
jit.save(net_trace, 'model.zip')
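Before shipping the saved file, it can be worth checking that a save/load round trip reproduces the eager outputs. A minimal sketch, assuming a small nn.Sequential stand-in (not the Model above) so the snippet runs on its own, and an in-memory io.BytesIO buffer in place of 'model.zip'; jit.save and jit.load accept file-like objects as well as file names.

```python
import io

import torch
from torch import nn, jit

# Stand-in model (assumption): any traceable module is exported the same way.
net = nn.Sequential(
    nn.Conv2d(3, 10, 3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(10),
    nn.Flatten(),
    nn.Linear(10 * 16 * 16, 2),
)
net.eval()

x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)

# Save to and load from memory instead of a file on disk.
buffer = io.BytesIO()
jit.save(net_trace, buffer)
buffer.seek(0)
loaded = jit.load(buffer)

with torch.no_grad():
    expected = net(x)
    actual = loaded(x)
print(torch.allclose(expected, actual))  # True
```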

If this succeeds, we can load the model in a new Python script without defining Model.

import torch
from torch import jit
net = jit.load('model.zip')

# print example output (should match the output printed before saving)
x = torch.ones(1, 3, 16, 16)
print(net(x))

The loaded model is also trainable; however, it will only behave in the mode it was exported in. In this case we exported the model in eval() mode, so calling net.train() on the loaded module has no effect: layers such as BatchNorm2d keep using their running statistics instead of per-batch statistics.
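To see this concretely, here is a small sketch with a stand-in model containing only a BatchNorm1d layer (an assumption for brevity, not the Model above). The traced module is switched to training mode after tracing, yet its output still matches the eval-mode eager output rather than the training-mode one.

```python
import torch
from torch import nn, jit

torch.manual_seed(0)

# Stand-in model (assumption): BatchNorm behaves differently in train and eval mode.
net = nn.Sequential(nn.BatchNorm1d(4))
net.eval()

x = torch.randn(8, 4)
net_trace = jit.trace(net, x)  # traced while in eval mode

net_trace.train()  # sets the training flag, but the traced graph is fixed
with torch.no_grad():
    traced_out = net_trace(x)
    eval_out = net(x)   # eager model, still in eval mode
    net.train()
    train_out = net(x)  # eager model in training mode uses batch statistics

print(torch.allclose(traced_out, eval_out))   # True: traced graph kept eval behavior
print(torch.allclose(traced_out, train_out))  # False
```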


Control flow

A model like the following, whose behavior can change between forward passes, won't be exported correctly, because only the code path executed during jit.trace is recorded.

import torch
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fca = nn.Linear(20 * 4 * 4, 2)
        self.fcb = nn.Linear(20 * 4 * 4, 2)

        self.use_a = True

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        if self.use_a:
            x = self.fca(x.flatten(1))
        else:
            x = self.fcb(x.flatten(1))
        return x

We can still export the model as follows

import torch
from torch import jit

net = Model()
# ... train your model

net.eval()

# print example outputs for both branches
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

# save model (traced while net.use_a is False)
net_trace = jit.trace(net, x)
jit.save(net_trace, "model.ts")

In this case the example outputs are

a: tensor([[-0.0959,  0.0657]], grad_fn=<AddmmBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<AddmmBackward>)

However, loading

import torch
from torch import jit

net = jit.load("model.ts")

# will not match the output from before
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

results in

a: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)

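Both calls return the output of branch "b": the logic of branch "a" is not present because net.use_a was False when jit.trace was called, and setting use_a on the loaded module does not change the traced graph.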


Scripting

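These limitations can be overcome, but doing so requires some effort on your end. Instead of tracing, you can use scripting (torch.jit.script), which compiles the source of your forward method rather than recording a single run, so control flow such as the if self.use_a branch is exported as well. Scripting only supports a subset of Python, so you may need to adjust your code before it compiles.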
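A minimal sketch of the scripting route, reusing the control-flow Model from above with jit.script in place of jit.trace. An io.BytesIO buffer stands in for the "model.ts" file so the round trip fits in one script; in practice you would save to a file and load it from a separate script as before. Because use_a is a plain bool attribute (not declared as a constant), scripting keeps it as a mutable attribute that forward reads at run time, so toggling it on the loaded module selects the matching branch.

```python
import io

import torch
from torch import nn, jit
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fca = nn.Linear(20 * 4 * 4, 2)
        self.fcb = nn.Linear(20 * 4 * 4, 2)
        self.use_a = True  # bool attribute, kept as a scripted attribute

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = self.bn1(x)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = self.bn2(x)
        if self.use_a:  # compiled as real control flow, not a recorded path
            x = self.fca(x.flatten(1))
        else:
            x = self.fcb(x.flatten(1))
        return x


net = Model()
# ... train your model
net.eval()

# Compile the module from its source instead of tracing one example run.
net_script = jit.script(net)

# In-memory buffer standing in for "model.ts".
buffer = io.BytesIO()
jit.save(net_script, buffer)
buffer.seek(0)
loaded = jit.load(buffer)

x = torch.ones(1, 3, 16, 16)
with torch.no_grad():
    for flag in (True, False):
        net.use_a = flag
        loaded.use_a = flag  # updates the attribute that the scripted forward reads
        print(flag, torch.allclose(net(x), loaded(x)))  # second value should be True
```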

jodag answered Jan 02 '23 09:01