
tf.GraphKeys.TRAINABLE_VARIABLES on output_graph.pb resulting in empty list

I'm trying to extract all the weights/biases from a saved model output_graph.pb.

I read the model:

import tensorflow as tf

def create_graph(modelFullPath):
    """Parses a saved GraphDef file and imports it into the default graph."""
    # Read the serialized GraphDef and import its nodes with no name prefix.
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)

I then tried the following, hoping to extract all the weights/biases within each layer:

with tf.Session() as sess:
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print(len(all_vars))

However, the printed length is 0, i.e. the collection is an empty list.

My final goal is to extract the weights and biases and save them to a text file or as NumPy arrays.

asked Oct 11 '17 20:10 by Moondra


1 Answer

The tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection: a GraphDef stores only the graph's nodes, while collections are stored in a MetaGraphDef. However, if output_graph.pb contains a "frozen" GraphDef, then all of the weights will be stored in tf.constant() nodes (ops of type "Const") in the graph. To extract them, you can do something like the following:

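create_graph(GRAPH_DIR)

constant_values = {}

with tf.Session() as sess:
  # In a frozen GraphDef, each former variable has become a "Const" op.
  constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
  for constant_op in constant_ops:
    # Evaluating the op's single output returns its value as a NumPy array.
    constant_values[constant_op.name] = sess.run(constant_op.outputs[0])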
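Because sess.run() returns NumPy arrays, constant_values is a plain dict of op name to np.ndarray, so the "save to a text file/np.arrays" step needs only NumPy. The following is a minimal sketch under that assumption; the helper name save_constants and the output paths are hypothetical. np.savetxt only accepts 1-D or 2-D data, so the sketch promotes scalars to 1-D, flattens higher-rank tensors (such as conv kernels) to 2-D, and records the original shape in each header line.

```python
import numpy as np


def save_constants(constant_values, npz_path, txt_path):
    """Save a {op_name: ndarray} dict as .npz and as a readable text file.

    Hypothetical helper: `constant_values` is assumed to be the dict built
    by the loop above (op name -> NumPy array returned by sess.run).
    """
    # Keep only numeric arrays; a frozen graph may also hold string constants.
    numeric = {name: np.asarray(value) for name, value in constant_values.items()
               if np.issubdtype(np.asarray(value).dtype, np.number)}

    # Binary format: one entry per op name, reloadable with np.load(npz_path).
    np.savez(npz_path, **numeric)

    # Text format: np.savetxt handles only 1-D/2-D data, so scalars are
    # promoted to 1-D and higher-rank tensors are reshaped to 2-D.
    with open(txt_path, "w") as f:
        for name, arr in numeric.items():
            flat = np.atleast_1d(arr)
            if flat.ndim > 2:
                flat = flat.reshape(-1, flat.shape[-1])
            # The header line ("# name shape=...") keeps the original shape.
            np.savetxt(f, flat, header=f"{name} shape={arr.shape}")
    return numeric
```

For example, save_constants(constant_values, "weights.npz", "weights.txt") writes both files; np.load("weights.npz")["some/op/name"] then returns the original array. The .npz file is the lossless option; the text file is mainly for inspection.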

Note that constant_values will probably contain more values than just the weights, so you may need to filter further by op.name or some other criterion.
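As one possible way to do that filtering, the sketch below keeps floating-point constants whose last name component mentions a keyword. The function name select_weights and the default keywords ("weight", "bias", "kernel") are assumptions: actual op names depend on how the model was built, so it is worth printing list(constant_values) first and adjusting the keywords to match.

```python
import numpy as np


def select_weights(constant_values, name_keywords=("weight", "bias", "kernel")):
    """Heuristically pick weight/bias tensors out of all Const values.

    Hypothetical helper: the keyword list is an assumption about op naming,
    not something TensorFlow guarantees.
    """
    selected = {}
    for name, value in constant_values.items():
        arr = np.asarray(value)
        # Weights and biases are floating point; shape/axis constants are usually ints.
        if not np.issubdtype(arr.dtype, np.floating):
            continue
        # Match only the last component, e.g. "final_weights" in "a/b/final_weights".
        leaf = name.rsplit("/", 1)[-1].lower()
        if any(keyword in leaf for keyword in name_keywords):
            selected[name] = arr
    return selected
```

Matching on the last name component rather than the full path avoids picking up unrelated float constants that merely live under a scope named "weights"; passing a different name_keywords tuple adapts it to other naming schemes.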

answered Nov 01 '22 07:11 by mrry