I'm trying to extract all the weights/biases from a saved model, output_graph.pb.
I read the model:
import tensorflow as tf

def create_graph(modelFullPath):
    """Loads a saved GraphDef file and imports it into the default graph."""
    # Read the serialized GraphDef from the .pb file.
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Import the nodes into the default graph without a name prefix.
        tf.import_graph_def(graph_def, name='')

GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)
I then tried the following, hoping to extract all the weights/biases in each layer:
with tf.Session() as sess:
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print(len(all_vars))
However, len(all_vars) comes back as 0.
My final goal is to extract the weights and biases and save them to a text file or NumPy arrays.
The tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection (for that, you would need a MetaGraphDef). However, if output_graph.pb contains a "frozen" GraphDef, then all of the weights will be stored in tf.constant() nodes in the graph. To extract them, you can do something like the following:
create_graph(GRAPH_DIR)

constant_values = {}

with tf.Session() as sess:
    # In a frozen graph, the weights live in "Const" ops.
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
Note that constant_values will probably contain more values than just the weights, so you may need to filter further by op.name or some other criterion.
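As a rough sketch of that filtering step, plus the "save to a text file/np.arrays" goal from the question, something like the following might work. It only needs NumPy and operates on a constant_values dict like the one built above. The helper names (select_weights, save_npz, save_txt) and the "weights"/"biases" keywords are my own assumptions, not anything TensorFlow defines, so adjust the keywords to whatever the op names in your graph actually look like.

```python
import numpy as np


def select_weights(constant_values, keywords=("weights", "biases")):
    """Keep only entries whose op name contains one of the keywords.

    The default keywords are a guess at a naming scheme; inspect the
    op names in your own graph and change them as needed.
    """
    return {
        name: np.asarray(value)
        for name, value in constant_values.items()
        if any(keyword in name for keyword in keywords)
    }


def save_npz(weights, path):
    """Save all arrays into one .npz archive.

    Op names contain "/", so it is replaced with "_" to get flat keys.
    """
    flat = {name.replace("/", "_"): arr for name, arr in weights.items()}
    np.savez(path, **flat)


def save_txt(array, path):
    """Write one array to a text file, one value per line.

    np.savetxt only handles 1-D and 2-D data, so the array is flattened
    and its original shape is recorded in a '#' header comment.
    """
    array = np.asarray(array)
    np.savetxt(path, array.reshape(-1), header="shape: {}".format(array.shape))
```

Usage would be along the lines of weights = select_weights(constant_values), then save_npz(weights, "weights.npz") to keep everything in one file, or save_txt(arr, ...) per array if you need plain text. The .npz route preserves shapes and dtypes, so it is probably the easier one to load back with np.load.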