I am looking for a way to save a PyTorch model and load it without the model definition. By this I mean that I want to save my model including its definition.
For example, I would like to have two scripts: the first defines, trains, and saves the model; the second loads the model and runs predictions, without including the model definition.
The approach using torch.save() and torch.load() requires me to include the model definition in the prediction script, but I want to find a way to load a model without redefining it there.
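To illustrate the problem, here is a minimal sketch of the pattern I mean (the nn.Sequential here just stands in for a custom model class):
import torch
from torch import nn

# script 1: torch.save pickles the module, but pickle stores only a
# *reference* to the class, not its source code
model = nn.Sequential(nn.Linear(4, 2))
torch.save(model, 'model.pt')

# script 2 (a separate file): this happens to work for nn.Sequential because
# the class ships with torch, but torch.load on a custom Model class raises
# an AttributeError unless that class is importable in this script
model = torch.load('model.pt')
print(model(torch.ones(1, 4)))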
You can attempt to export your model to TorchScript using tracing. This has limitations: because PyTorch constructs the model's computation graph on the fly, tracing only records the operations executed for the example input, so if your model contains any control flow, the exported model may not fully represent your Python module. TorchScript is only supported in PyTorch >= 1.0.0, though I would recommend using the latest version possible.
For example, a model without any conditional behavior is fine:
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc = nn.Linear(20 * 4 * 4, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        x = self.fc(x.flatten(1))
        return x
We can export this as follows:
import torch
from torch import jit

net = Model()
# ... train your model

# put the model in the mode you want to export (see the note below)
net.eval()

# print example output
x = torch.ones(1, 3, 16, 16)
print(net(x))

# create TorchScript by tracing the computation graph with an example input
net_trace = jit.trace(net, x)
jit.save(net_trace, 'model.zip')
If successful, then we can load our model into a new Python script without needing the Model class:
import torch
from torch import jit

net = jit.load('model.zip')

# print example output (should match the output printed before saving)
x = torch.ones(1, 3, 16, 16)
print(net(x))
The loaded model is also trainable; however, it will only behave in the mode it was exported in. For example, here we exported our model in eval() mode, so calling net.train() on the loaded module will have no effect.
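If you want to convince yourself of this, a quick check (a sketch assuming the model.zip exported above) is to compare outputs before and after calling train():
import torch
from torch import jit

net = jit.load('model.zip')
x = torch.ones(1, 3, 16, 16)

out_eval = net(x)
net.train()  # flips the training flag, but the traced graph is unchanged
out_train = net(x)

# the BatchNorm layers were traced with eval() behavior baked in, so they
# keep using their running statistics and both outputs are identical
print(torch.allclose(out_eval, out_train))  # True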
A model like the following, whose behavior changes between passes, won't be exported properly: only the code path evaluated during jit.trace is recorded.
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fca = nn.Linear(20 * 4 * 4, 2)
        self.fcb = nn.Linear(20 * 4 * 4, 2)
        self.use_a = True

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        # data-independent control flow: the tracer only ever sees one branch
        if self.use_a:
            x = self.fca(x.flatten(1))
        else:
            x = self.fcb(x.flatten(1))
        return x
We can still export the model as follows:
import torch
from torch import jit

net = Model()
# ... train your model
net.eval()

# print example outputs for both branches
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

# save the model (note that use_a is False at this point)
net_trace = jit.trace(net, x)
jit.save(net_trace, "model.ts")
In this case the example outputs are
a: tensor([[-0.0959, 0.0657]], grad_fn=<AddmmBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<AddmmBackward>)
However, loading
import torch
from torch import jit
net = jit.load("model.ts")
# the 'a' output will not match the one printed before saving
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))
results in
a: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
Notice that the logic of branch "a" is not present, since net.use_a was False when jit.trace was called.
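You can see this directly by printing the TorchScript generated for the traced module (assuming net_trace from the export step is still in scope; the exact output varies by PyTorch version):
# only the fcb branch appears; there is no if statement and no call to fca,
# because use_a was False when jit.trace ran
print(net_trace.code)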
These limitations can be overcome, but require some effort on your end: you can use TorchScript's scripting functionality (torch.jit.script) to ensure that all of the logic is exported, regardless of which branch happened to run at export time.
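As a minimal sketch of that approach, applied to the conditional Model above: scripting compiles the Python source of forward() itself, so the if/else on use_a survives the export.
import torch
from torch import jit

net = Model()
net.eval()

# jit.script compiles forward() including its control flow, instead of
# recording a single execution the way jit.trace does
net_script = jit.script(net)
jit.save(net_script, 'model_scripted.ts')

# in the prediction script, both branches are now available
net2 = jit.load('model_scripted.ts')
x = torch.ones(1, 3, 16, 16)
net2.use_a = True
print('a:', net2(x))
net2.use_a = False
print('b:', net2(x))  # now differs from 'a', unlike with the traced model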