I want to train my Tensorflow model, freeze a snapshot, and then run it in feed-forward mode (no further training) with new input data. Questions:
- Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?
- Do I need to include, in collection_list, the names of all variables that I want included in the snapshot? (Simplest for me would be to include everything.)
- The Tensorflow docs say: "If no collection_list is specified, all collections in the model will be exported." Does that mean that if I specify no variables in collection_list then all variables in the model are exported because they are in the default collection?
- The Tensorflow docs say: "In order for a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does that mean that to_proto() and from_proto() must be added only to classes that I have defined and want exported? If I am using only standard Python data types (int, float, list, dict) then is this irrelevant?
Thanks in advance.
A MetaGraph contains both a TensorFlow GraphDef as well as associated metadata necessary for running computation in a graph when crossing a process boundary. It can also be used for long term storage of graphs.
A bit late but I'll still try to answer.
- Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?
I would say so. Note that tf.train.export_meta_graph is called for you implicitly when you save a model via tf.train.Saver. The gist is:
# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
    ...
    # save graph and variables
    # if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
    saver.save(sess, save_path, global_step)
Then to restore:
save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
Note that instead of calling tf.train.import_meta_graph you could also call the original piece of code that you used to create the model in the first place. However, I think it is more elegant to use import_meta_graph, as this way you can also restore your model even if you don't have access to the code that created it.
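A nice consequence of going through import_meta_graph is that you can look tensors up purely by name after restoring, without any model-building code. A minimal sketch, assuming the original graph contained tensors named 'input:0' and 'output:0' (hypothetical names, substitute whatever your model actually uses; the checkpoint directory is also just a placeholder):
import numpy as np
import tensorflow as tf

latest_checkpoint = tf.train.latest_checkpoint('/path/to/checkpoints')  # illustrative path
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')

with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
    graph = tf.get_default_graph()
    # look the tensors up by name -- no access to the original model code needed
    input_tensor = graph.get_tensor_by_name('input:0')    # hypothetical name
    output_tensor = graph.get_tensor_by_name('output:0')  # hypothetical name
    result = sess.run(output_tensor,
                      feed_dict={input_tensor: np.zeros((64, 64), dtype=np.float32)})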
- Do I need to include, in collection_list, the names of all variables that I want included in the snapshot? (Simplest for me would be to include everything.)
No. However, the question is a bit confusing: the collection_list in export_meta_graph is not meant to be a list of variables, but of collections (i.e. a list of string keys).
Collections are quite handy, e.g. all trainable variables are automatically included in the collection tf.GraphKeys.TRAINABLE_VARIABLES, which you can get by calling:
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
or
tf.trainable_variables() # defaults to the default graph
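If you call export_meta_graph yourself and only want certain collections written out, collection_list takes exactly those string keys. A minimal sketch (the file name 'my_model.meta' is illustrative, not from the original post):
import tensorflow as tf

# export only the named collections; anything not listed is left out of the MetaGraphDef
meta_graph_def = tf.train.export_meta_graph(
    filename='my_model.meta',
    collection_list=[tf.GraphKeys.GLOBAL_VARIABLES,
                     tf.GraphKeys.TRAINABLE_VARIABLES])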
If, after restoration, you need access to intermediate results other than your trainable variables, I find it quite convenient to put those into a custom collection, like this:
...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
...
tf.add_to_collection('my_custom_collection', input_)
This collection is stored automatically (unless you explicitly exclude it by omitting its name from the collection_list argument to export_meta_graph). So you can simply retrieve the input_ placeholder after restoration as follows:
...
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
    input_ = tf.get_collection_ref('my_custom_collection')[0]
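From there, running the frozen model feed-forward on new data is just a sess.run with a feed_dict. A minimal sketch, assuming you also added an output tensor to the same collection at index 1 (that second entry is an assumption; the snippet above only stores the placeholder):
import numpy as np

with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
    # index 0 is the placeholder stored above; index 1 is a hypothetical output tensor
    input_ = tf.get_collection_ref('my_custom_collection')[0]
    output = tf.get_collection_ref('my_custom_collection')[1]
    new_data = np.random.rand(64, 64).astype(np.float32)  # new input, shape matches the placeholder
    prediction = sess.run(output, feed_dict={input_: new_data})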
- The Tensorflow docs say: "If no collection_list is specified, all collections in the model will be exported." Does that mean that if I specify no variables in collection_list then all variables in the model are exported because they are in the default collection?
Yes. Again, note the subtle detail that the collection_list is a list of collections, not variables. In fact, if you only want certain variables to be saved, you can specify those when you construct the tf.train.Saver object. From the documentation of tf.train.Saver.__init__:
"""Creates a `Saver`.
The constructor adds ops to save and restore variables.
`var_list` specifies the variables that will be saved and restored. It can
be passed as a `dict` or a list:
* A `dict` of names to variables: The keys are the names that will be
used to save or restore the variables in the checkpoint files.
* A list of variables: The variables will be keyed with their op name in
the checkpoint files.
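In other words, a sketch like the following saves only the listed variables (the variables w and b are made up for illustration):
w = tf.Variable(tf.zeros([64, 10]), name='w')
b = tf.Variable(tf.zeros([10]), name='b')

# save only these two variables, either as a list (keyed by their op names) ...
saver = tf.train.Saver(var_list=[w, b])
# ... or as a dict, choosing the names used in the checkpoint yourself
saver = tf.train.Saver(var_list={'weights': w, 'biases': b})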
- The Tensorflow docs say: "In order for a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does that mean that to_proto() and from_proto() must be added only to classes that I have defined and want exported? If I am using only standard Python data types (int, float, list, dict) then is this irrelevant?
I have never made use of this feature, but I would say your interpretation is correct.