 

Weights and Bias from Trained Meta Graph

I have successfully exported a retrained InceptionV3 NN as a TensorFlow meta graph, and I have read this protobuf back into Python. However, I am struggling to find a way to export each layer's weight and bias values, which I assume are stored within the meta graph protobuf, so that I can recreate the NN outside of TensorFlow.

My workflow is as follows:

  1. Retrain the final layer for new categories.
  2. Export the meta graph with tf.train.export_meta_graph(filename='model.meta').
  3. Generate a Python meta_graph_pb2.py module from meta_graph.proto using protoc.
  4. Load the protobuf:

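import meta_graph_pb2  # generated by protoc from meta_graph.proto

# A .meta file holds a serialized MetaGraphDef, not a CollectionDef.
saved = meta_graph_pb2.MetaGraphDef()
with open('model.meta', 'rb') as f:
  saved.ParseFromString(f.read())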

From here I can view most aspects of the graph, such as node names, but my inexperience is making it difficult to track down the correct way to access the weight and bias values for each relevant layer.
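A minimal sketch of the inspection described above, assuming TensorFlow 1.x-style graph mode accessed through tensorflow.compat.v1 (so it also runs under TensorFlow 2). TensorFlow already ships the generated module as tensorflow.core.protobuf.meta_graph_pb2, so the protoc step is optional. The toy graph and the names final/weights and final/biases are hypothetical stand-ins for the retrained layer. The sketch lists each variable node together with the shape recorded in its shape attribute.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.core.protobuf import meta_graph_pb2

tf.disable_v2_behavior()  # graph mode and TF1-style variables

# Op types TensorFlow uses for variable nodes (reference and resource variables).
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}


def export_toy_meta_graph(directory):
    """Hypothetical stand-in for the retrained model: export a tiny layer's meta graph."""
    graph = tf.Graph()
    with graph.as_default():
        with tf.variable_scope("final"):
            tf.get_variable("weights", shape=[4, 3])
            tf.get_variable("biases", shape=[3], initializer=tf.zeros_initializer())
        path = os.path.join(directory, "model.meta")
        tf.train.export_meta_graph(filename=path)
    return path


def list_variable_nodes(meta_path):
    """Return {node name: shape} for every variable node in a .meta file."""
    meta = meta_graph_pb2.MetaGraphDef()  # a .meta file is a serialized MetaGraphDef
    with open(meta_path, "rb") as f:
        meta.ParseFromString(f.read())
    return {
        node.name: [dim.size for dim in node.attr["shape"].shape.dim]
        for node in meta.graph_def.node
        if node.op in VARIABLE_OPS
    }


if __name__ == "__main__":
    meta_path = export_toy_meta_graph(tempfile.mkdtemp())
    for name, shape in list_variable_nodes(meta_path).items():
        print(name, shape)
```

The variable nodes record names, shapes, and dtypes; what they do not hold is the trained values themselves.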

Vinny M asked Jan 05 '23 11:01


1 Answer

The MetaGraphDef proto doesn't actually contain the values of the weights and biases. Instead it provides a way to associate a GraphDef with the weights stored in one or more checkpoint files, written by a tf.train.Saver. The MetaGraphDef tutorial has more details, but the approximate structure is as follows:

  1. In your training program, write out a checkpoint using a tf.train.Saver. This also writes a MetaGraphDef to a .meta file in the same directory.

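    import tensorflow as tf  # TensorFlow 1.x API
    # ... build the model; create the Saver once its variables exist ...
    saver = tf.train.Saver()  # with no arguments, saves every variable
    # ... train in a tf.Session called sess ...
    saver.save(sess, "model")  # also writes the MetaGraphDef to model.meta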
    

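    You should find model.meta and the checkpoint data files in your checkpoint directory. If you pass global_step=NNNN to saver.save(), the step number is appended to the names, so the checkpoint prefix becomes model-NNNN (for some integer NNNN) and the graph file becomes model-NNNN.meta.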

  2. In another program, you can import the MetaGraphDef you just created, and restore from a checkpoint.

    import tensorflow as tf
    saver = tf.train.import_meta_graph("model.meta")  # rebuilds the graph, returns a Saver
    sess = tf.Session()
    saver.restore(sess, "model-NNNN")  # Or whatever checkpoint filename was written.
    

    If you want to get the value of each variable, you can (for example) find the variable in the list returned by tf.all_variables() (renamed tf.global_variables() in later TensorFlow releases) and pass it to sess.run() to get its value. For example, to print the values of all variables, you can do the following:

    for var in tf.global_variables():  # tf.all_variables() in older releases
      print(var.name, sess.run(var))
    

    You could also filter tf.all_variables() to find the particular weights and biases that you're trying to extract from the model.
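Putting both steps together, here is a minimal end-to-end sketch, assuming TensorFlow 1.x-style graph mode accessed through tensorflow.compat.v1, plus NumPy. The toy layer, the tensor names input and output, and the variable names final/weights and final/biases are hypothetical; in the retrained InceptionV3 graph you would filter on whatever names the printing loop shows. extract_weights follows the route described above (import_meta_graph, restore, filter the variables, sess.run). read_checkpoint is an alternative that reads the same values straight from the checkpoint with tf.train.NewCheckpointReader, without rebuilding the graph. numpy_dense_softmax then recreates the layer outside TensorFlow.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, tf.Session and tf.train.Saver


def train_and_save(directory):
    """Hypothetical training program: a dense softmax layer with fixed values."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, [None, 2], name="input")
        with tf.name_scope("final"):
            w = tf.Variable(np.array([[1.0, -1.0, 0.5], [0.0, 2.0, -0.5]], np.float32),
                            name="weights")
            b = tf.Variable(np.array([0.1, 0.2, 0.3], np.float32), name="biases")
        tf.nn.softmax(tf.matmul(x, w) + b, name="output")
        saver = tf.train.Saver()  # saves every variable
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # Writes model.meta (the graph) plus the checkpoint data files.
            return saver.save(sess, os.path.join(directory, "model"))


def extract_weights(save_path, keep=("weights", "biases")):
    """Rebuild the graph from the .meta file, restore values, filter variables."""
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(save_path + ".meta")
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, save_path)
            # Keep variables whose last name component is in `keep`.
            wanted = [v for v in tf.global_variables()
                      if v.op.name.split("/")[-1] in keep]
            return {v.op.name: sess.run(v) for v in wanted}


def read_checkpoint(save_path):
    """Alternative: read every saved variable directly, without the graph."""
    reader = tf.train.NewCheckpointReader(save_path)
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map()}


def numpy_dense_softmax(x, weights, biases):
    """Recreate the layer outside TensorFlow: softmax(x @ W + b)."""
    logits = x @ weights + biases
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


if __name__ == "__main__":
    path = train_and_save(tempfile.mkdtemp())
    params = extract_weights(path)
    x = np.array([[1.0, 2.0]], np.float32)
    print(numpy_dense_softmax(x, params["final/weights"], params["final/biases"]))
```

Filtering on the last name component is just one convenient choice; matching on a scope prefix with var.name.startswith(...) works equally well when the layer of interest lives under a known scope.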

mrry answered Jan 13 '23 10:01