
Tensorflow Metagraph Fundamentals

Tags:

tensorflow

I want to train my Tensorflow model, freeze a snapshot, and then run it in feed-forward mode (no further training) with new input data. Questions:

  1. Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?
  2. Do I need to include, in collection_list, the names of all variables that I want included in the snapshot? (Simplest for me would be to include everything.)
  3. The Tensorflow docs say: "If no collection_list is specified, all collections in the model will be exported." Does that mean that if I specify no variables in collection_list then all variables in the model are exported because they are in the default collection?
  4. The Tensorflow docs say: "In order for a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does that mean that to_proto() and from_proto() must be added only to classes that I have defined and want exported? If I am using only standard Python data types (int, float, list, dict) then is this irrelevant?

Thanks in advance.

asked Oct 01 '16 by Ron Cohen



1 Answer

A bit late but I'll still try to answer.

  1. Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?

I would say so. Note that tf.train.export_meta_graph is called for you implicitly when you save a model via tf.train.Saver. The gist is:

# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
    ...
    # save graph and variables
    # by default the Saver keeps only the 5 most recent checkpoints (max_to_keep=5)
    saver.save(sess, save_path, global_step)

Then to restore:

save_path = ...  # the directory containing the checkpoint files
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)

Note that instead of calling tf.train.import_meta_graph you could also re-run the original piece of code that built the model in the first place. However, I find it more elegant to use import_meta_graph, since it lets you restore a model even if you don't have access to the code that created it.


  2. Do I need to include, in collection_list, the names of all variables that I want included in the snapshot? (Simplest for me would be to include everything.)

No. However, the question is a bit confusing: the collection_list argument of export_meta_graph is not a list of variables but a list of collections (i.e. a list of string keys).

Collections are quite handy, e.g. all trainable variables are automatically included in the collection tf.GraphKeys.TRAINABLE_VARIABLES which you can get by calling:

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

or

tf.trainable_variables()  # defaults to the default graph
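As a quick sketch of this behavior (variable names illustrative; tf.compat.v1 API): only variables created with trainable=True end up in that collection:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.reset_default_graph()

v = tf.get_variable('v', shape=[2])  # trainable by default
frozen = tf.get_variable('frozen', shape=[2], trainable=False)

trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
names = [t.op.name for t in trainables]
print(names)  # ['v'] -- 'frozen' is not in the collection
```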

If after restoration you need access to other intermediary results than your trainable variables, I find it quite convenient to put those into a custom collection, like this:

...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)

This collection is stored automatically (unless you explicitly exclude it by omitting its name from the collection_list argument of export_meta_graph). So you can simply retrieve the input_ placeholder after restoration as follows:

...
with tf.Session() as sess:
    saver.restore(sess, latest_checkpoint)
    input_ = tf.get_collection_ref('my_custom_collection')[0]

  3. The Tensorflow docs say: "If no collection_list is specified, all collections in the model will be exported." Does that mean that if I specify no variables in collection_list then all variables in the model are exported because they are in the default collection?

Yes. Again, note the subtle detail that collection_list is a list of collections, not of variables. In fact, if you only want certain variables to be saved, you can specify those when constructing the tf.train.Saver object. From the documentation of tf.train.Saver.__init__:

"""Creates a `Saver`.

The constructor adds ops to save and restore variables.

`var_list` specifies the variables that will be saved and restored. It can
be passed as a `dict` or a list:

* A `dict` of names to variables: The keys are the names that will be
  used to save or restore the variables in the checkpoint files.
* A list of variables: The variables will be keyed with their op name in
  the checkpoint files.
"""
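A short sketch of the var_list option (hypothetical variable names, tf.compat.v1 API): passing a dict makes the Saver write only the listed variables, under the given keys:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.reset_default_graph()

a = tf.get_variable('a', shape=[])
b = tf.get_variable('b', shape=[])

# Save only 'a', stored in the checkpoint under the key 'my_a'.
saver = tf.train.Saver(var_list={'my_a': a})
path = os.path.join(tempfile.mkdtemp(), 'model')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, path)

saved_names = [name for name, shape in tf.train.list_variables(path)]
print(saved_names)  # 'b' was not saved
```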

  4. The Tensorflow docs say: "In order for a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does that mean that to_proto() and from_proto() must be added only to classes that I have defined and want exported? If I am using only standard Python data types (int, float, list, dict) then is this irrelevant?

I have never made use of this feature, but I would say your interpretation is correct.

answered Oct 24 '22 by kafman