How can we convert a .pth model into a .pb file?

I have already trained a complete model using PyTorch, but I want to convert the .pth file into a .pb file that can be used in Tensorflow. Does anyone have any ideas?

Asked Dec 23 '19 by Rafael

People also ask

Can you convert a PyTorch model to Tensorflow?

ONNX (Open Neural Network Exchange) is a way of easily porting models among different frameworks, such as PyTorch, Tensorflow, Keras, Caffe2, and CoreML. Most of these frameworks now support the ONNX format.
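For example, once a model has been exported to ONNX, you can validate the interchange file before importing it into another framework. A minimal sketch (the file name is hypothetical):

import onnx

# Load the exported model and run ONNX's structural checks on it
onnx_model = onnx.load("model.onnx")  # hypothetical path to an exported model
onnx.checker.check_model(onnx_model)  # raises an error if the model is malformed
print(onnx.helper.printable_graph(onnx_model.graph))  # human-readable graph summary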

What is a .pb file in Tensorflow?

The .pb format is the protocol buffer (protobuf) format, which Tensorflow uses to store models. Protobuf is a general-purpose serialization format from Google that is convenient to transport, as it compacts the data efficiently and enforces a structure on it.
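As an illustration, a frozen-graph style .pb file can be read back into Tensorflow roughly like this (a sketch using the TF1 compatibility API; it assumes the .pb holds a serialized GraphDef rather than a SavedModel):

import tensorflow as tf

# Read the serialized GraphDef (protobuf) from disk
with tf.io.gfile.GFile('output/mnist.pb', 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph and list a few operation names
with tf.Graph().as_default() as graph:
    tf.compat.v1.import_graph_def(graph_def, name='')
    print([op.name for op in graph.get_operations()][:5])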


1 Answer

You can use ONNX (Open Neural Network Exchange format).

To convert a .pth file to .pb, first export the model defined in PyTorch to ONNX, and then import the ONNX model into Tensorflow (PyTorch => ONNX => Tensorflow).

Below is an example, based on the MNIST model from onnx/tutorials, of converting a PyTorch model to Tensorflow using ONNX.
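For completeness, the snippets below assume the following imports and that Net is the MNIST model class defined in that tutorial (a sketch; the onnx and onnx_tf imports follow the note at the end of this answer):

import torch
from torch.autograd import Variable

import onnx
from onnx_tf.backend import prepare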

Save the trained model to a file

torch.save(model.state_dict(), 'output/mnist.pth')

Load the trained model from file

trained_model = Net()  # Net is the MNIST model class defined earlier in the tutorial
trained_model.load_state_dict(torch.load('output/mnist.pth'))

# Export the trained model to ONNX
dummy_input = Variable(torch.randn(1, 1, 28, 28)) # one black and white 28 x 28 picture will be the input to the model
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")

Load the ONNX file

model = onnx.load('output/mnist.onnx')

# Import the ONNX model to Tensorflow
tf_rep = prepare(model)

Save the Tensorflow model into a file

tf_rep.export_graph('output/mnist.pb')
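Before relying on the exported file, you can also sanity-check the converted model by running the ONNX-TF representation directly on a dummy input (a sketch, assuming the 1 x 1 x 28 x 28 input shape used above):

import numpy as np

# Feed a dummy image through the imported graph to confirm the conversion works
sample = np.random.randn(1, 1, 28, 28).astype(np.float32)
print(tf_rep.run(sample))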

As noted by @tsveti_iko in the comments:

NOTE: prepare() is built into onnx-tf, so you first need to install it from the console with pip install onnx-tf, then import it in your code with import onnx and from onnx_tf.backend import prepare; after that you can use it as described in the answer.

Answered Dec 01 '22 by Dishin H Goyani