I am looking at Google's example of how to deploy and use a pre-trained TensorFlow graph (model) on Android. This example uses a .pb file at:
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
which is a link to a file that downloads automatically.
The example shows how to load the .pb file into a TensorFlow session and use it to perform classification, but it doesn't seem to mention how to generate such a .pb file after a graph is trained (e.g., in Python).
Are there any examples of how to do that?
The .pb file is a TensorFlow protocol buffer. Since protocol buffers use a structured format when storing data, they can be represented with Python classes. In TensorFlow, for example, the tf.train.Example class represents the protocol buffer used to store data for the input pipeline.
The GraphDef class is an object created by the protobuf library from the definition in tensorflow/core/framework/graph.proto. The protobuf tools parse this text file and generate the code to load, store, and manipulate graph definitions.
Protocol Buffers (protobuf) is a free and open-source, cross-platform data format used to serialize structured data. It is useful for programs that communicate with each other over a network, or for storing data.
Protobuf is a binary transfer format, meaning the data is transmitted as binary rather than as raw strings. This makes transmission faster because the data takes less space and bandwidth, and because it is so compact, parsing it also costs less CPU.
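As a minimal sketch of what that means in practice (assuming the TensorFlow 1.x API; the toy graph, the tensor names, and the output path are made up for illustration), a GraphDef is just a protocol buffer that can be serialized straight to a .pb file:

import tensorflow as tf  # assumes the TensorFlow 1.x API

# Build a trivial graph; "x" and "y" are arbitrary example names.
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
    y = tf.identity(2.0 * x, name="y")

# g.as_graph_def() returns a GraphDef protocol buffer; write_graph
# serializes it (here in binary form) to /tmp/example/graph.pb.
tf.train.write_graph(g.as_graph_def(), "/tmp/example", "graph.pb", as_text=False)

A graph serialized this way contains only the graph structure, not the trained variable values; producing a self-contained .pb like the Inception one is what the answer below addresses.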
EDIT: The freeze_graph.py script, which is part of the TensorFlow repository, now serves as a tool that generates a protocol buffer representing a "frozen" trained model from an existing TensorFlow GraphDef and a saved checkpoint. It uses the same steps as described below, but it is much easier to use.
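A typical invocation looks roughly like the following; the paths and the output node name ("softmax") are placeholders for your own model, so check the script's flags for the exact options in your TensorFlow version:

# Hypothetical paths and node name; adjust to your own graph and checkpoint.
python tensorflow/python/tools/freeze_graph.py \
  --input_graph=/tmp/model/graph.pbtxt \
  --input_checkpoint=/tmp/model/model.ckpt \
  --output_node_names=softmax \
  --output_graph=/tmp/model/frozen_graph.pb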
Currently the process isn't very well documented (and subject to refinement), but the approximate steps are as follows:
1. Build and train your model in a tf.Graph called g_1.
2. Fetch the final values of each of the variables and store them as numpy arrays (using Session.run()).
3. In a new tf.Graph called g_2, create tf.constant() tensors for each of the variables, using the value of the corresponding numpy array fetched in step 2.
4. Use tf.import_graph_def() to copy nodes from g_1 into g_2, and use the input_map argument to replace each variable in g_1 with the corresponding tf.constant() tensors created in step 3. You may also want to use input_map to specify a new input tensor (e.g. replacing the input pipeline with a tf.placeholder()). Use the return_elements argument to specify the name of the predicted output tensor.
5. Call g_2.as_graph_def() to get a protocol buffer representation of the graph. (A code sketch of these steps follows the note below.)
(NOTE: The generated graph will contain extra nodes that are only needed for training. Although it is not part of the public API, you may wish to use the internal graph_util.extract_sub_graph() function to strip these nodes from the graph.)
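Here is a rough sketch of steps 1-5 above, assuming the TensorFlow 1.x API. The tiny linear model, the tensor names ("input", "output"), and the output path are invented for illustration; instead of mapping the variable tensors themselves, it maps each variable's read tensor (w/read:0, b/read:0) onto a constant, and it uses the internal graph_util.extract_sub_graph() mentioned in the note to drop the leftover variable and initializer nodes:

import tensorflow as tf  # assumes the TensorFlow 1.x API
from tensorflow.python.framework import graph_util  # internal module from the note

# Step 1: build (and normally train) the model in g_1. The model and the
# names used here are hypothetical placeholders.
g_1 = tf.Graph()
with g_1.as_default():
    inp = tf.placeholder(tf.float32, shape=[None, 4], name="input")
    w = tf.Variable(tf.random_normal([4, 2]), name="w")
    b = tf.Variable(tf.zeros([2]), name="b")
    tf.nn.softmax(tf.matmul(inp, w) + b, name="output")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... training would normally happen here ...
        # Step 2: fetch the final variable values as numpy arrays.
        w_val, b_val = sess.run([w, b])

# Steps 3 and 4: in a new graph, create constants holding the trained values
# and import g_1's GraphDef, rewiring the consumers of each variable (via its
# read tensor) onto the constants and attaching a fresh input placeholder.
g_2 = tf.Graph()
with g_2.as_default():
    w_const = tf.constant(w_val, name="w_const")
    b_const = tf.constant(b_val, name="b_const")
    new_input = tf.placeholder(tf.float32, shape=[None, 4], name="new_input")
    output, = tf.import_graph_def(
        g_1.as_graph_def(),
        input_map={"input:0": new_input,
                   "w/read:0": w_const,
                   "b/read:0": b_const},
        return_elements=["output:0"])

    # Step 5 (plus the note): take the GraphDef, strip the now-unused
    # variable/initializer nodes, and serialize the result to a .pb file.
    frozen = graph_util.extract_sub_graph(g_2.as_graph_def(), [output.op.name])
    with tf.gfile.GFile("/tmp/frozen_model.pb", "wb") as f:
        f.write(frozen.SerializeToString())

The resulting .pb contains only constants and the ops needed to compute the output, so it can be loaded into a TensorFlow session on Android in the same way as the Inception model in Google's example.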