From what I've gathered so far, there are several different ways of dumping a TensorFlow graph into a file and then loading it into another program, but I haven't been able to find clear examples/information on how they work. What I already know is this:
- Save the model's variables with a tf.train.Saver() and restore them later (source)
- Save and load a graph with tf.train.write_graph() and tf.import_graph_def() (source)
- Use as_graph_def() to save the model, and for weights/variables, map them into constants (source)

However, I haven't been able to clear up several questions regarding these different methods:
- With tf.train.write_graph(), are the weights/variables saved as well?
- Can a frozen graph be loaded back in using tf.import_graph_def()?
- What exactly is the difference between as_graph_def()/.ckpt/.pb?

In short, what I'm looking for is a method to save both a graph (as in, the various operations and such) and its weights/variables into a file, which can then be used to load the graph and weights into another program, for use (not necessarily continuing/retraining).
Documentation about this topic isn't very straightforward, so any answers/information would be greatly appreciated.
The .pb format is the protocol buffer (protobuf) format, and in TensorFlow this format is used to hold models. Protocol buffers are a general-purpose serialization format from Google that is convenient to transport, because it encodes data compactly and enforces a structure on the data.
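Since a GraphDef is just a protobuf message, it can be written and read with the standard protobuf serialization calls. A minimal sketch (the file name and graph contents here are arbitrary):

import tensorflow as tf

# Build a trivial graph so there is something to serialize.
a = tf.constant(1.0, name="a")
b = tf.constant(2.0, name="b")
c = tf.add(a, b, name="c")

# SerializeToString()/ParseFromString() are ordinary protobuf methods.
graph_def = tf.get_default_graph().as_graph_def()
with tf.gfile.GFile("graph.pb", "wb") as f:
    f.write(graph_def.SerializeToString())

restored = tf.GraphDef()
with tf.gfile.GFile("graph.pb", "rb") as f:
    restored.ParseFromString(f.read())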
There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:
The checkpoint files (produced e.g. by calling saver.save() on a tf.train.Saver object) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights. Note that calling saver.save() also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorial for more details.
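As a rough illustration of that workflow (the paths, variable names and shapes below are made up for the example):

import tensorflow as tf

# Training program: define variables, then snapshot their values.
w = tf.Variable(tf.random_normal([10, 10]), name="w")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "/tmp/model.ckpt")   # also writes /tmp/model.ckpt.meta

# Consuming program: the graph must be rebuilt (by running the same
# graph-building code, or via tf.train.import_meta_graph) before the
# checkpoint can be restored into it.
with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")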
tf.train.write_graph() only writes the graph structure, not the weights.
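For example, something like the following writes only the GraphDef (directory and file names are arbitrary):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 784], name="input")
w = tf.Variable(tf.zeros([784, 10]), name="w")
y = tf.matmul(x, w, name="output")

# Only the structure (nodes and their connections) is written; the values
# stored in w still need a checkpoint.
tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                     "/tmp/model", "graph.pbtxt", as_text=True)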
Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)
A frozen graph can be loaded using tf.import_graph_def(). In this case, the weights are (typically) embedded in the graph, so you don't need to load a separate checkpoint.
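Freezing itself is usually done with the freeze_graph tool, or equivalently by converting variables to constants in Python. A minimal sketch, assuming TF 1.x's tf.graph_util.convert_variables_to_constants and a made-up output node name:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 2], name="input")
w = tf.Variable(tf.ones([2, 1]), name="w")
y = tf.matmul(x, w, name="output")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Replace each variable with a constant holding its current value,
    # keeping only the ops needed to compute the listed output nodes.
    frozen_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names=["output"])

with tf.gfile.GFile("/tmp/frozen_graph.pb", "wb") as f:
    f.write(frozen_def.SerializeToString())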
The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the inputName and outputName strings that are passed to TensorFlowClassifier.initializeTensorFlow().
The GraphDef is the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as the MetaGraphDef libraries, Keras, and skflow, build on these mechanisms to provide more convenient ways to save and restore an entire model.
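For completeness, the MetaGraphDef route mentioned above looks roughly like this (the checkpoint paths are assumptions carried over from the earlier sketch):

import tensorflow as tf

# saver.save(sess, "/tmp/model.ckpt") also wrote /tmp/model.ckpt.meta,
# which holds the MetaGraphDef (graph structure plus saver bookkeeping).
saver = tf.train.import_meta_graph("/tmp/model.ckpt.meta")  # rebuilds the graph
with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")                  # loads the weights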
You can try the following code:
import tensorflow as tf

# Read the frozen GraphDef and import its nodes into the default graph.
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")

# The imported nodes live in the default graph, which the session uses.
sess = tf.Session()
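Once the graph is imported, inference is just a matter of looking up the input and output tensors by name and calling sess.run(); the tensor names below ("image_tensor:0", "detection_boxes:0") are only examples and depend on the particular model:

import numpy as np

input_t = sess.graph.get_tensor_by_name("image_tensor:0")
output_t = sess.graph.get_tensor_by_name("detection_boxes:0")
boxes = sess.run(output_t,
                 feed_dict={input_t: np.zeros((1, 300, 300, 3), dtype=np.uint8)})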