How do I import a frozen protobuf to enable it for re-training?
All the methods I've found online expect checkpoints. Is there a way to read a protobuf such that kernel and bias constants are converted to variables?
Edit 1: This is similar to the following question: How to retrain model in graph (.pb)?
I looked at DeepSpeech, which was recommended in the answers to that question. They seem to have removed support for initialize_from_frozen_model; I couldn't find the reason.
Edit 2: I tried creating a new GraphDef object where I replace the kernels and biases with Variables:
from tensorflow.core.framework import attr_value_pb2

probable_variables = [...] # kernels and biases of Conv2D and MatMul

new_graph_def = tf.GraphDef()
with tf.Session(graph=graph) as sess:
    for n in sess.graph_def.node:
        if n.name in probable_variables:
            # create variable op (dtype and shape taken from the original Const node)
            nn = new_graph_def.node.add()
            nn.name = n.name
            nn.op = 'VariableV2'
            nn.attr['dtype'].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
            nn.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(shape=shape))
        else:
            nn = new_graph_def.node.add()
            nn.CopyFrom(n)
Not sure if I am on the right path. I don't know how to set trainable=True in a NodeDef object.
You were actually headed in the right direction with the snippet you provided :)
The trickiest part is getting the names of the previously trainable variables. Hopefully the model was created with a high-level framework like keras or tf.slim, which wrap their variables nicely in names like conv2d_1/kernel, dense_1/bias, batch_normalization/gamma, etc.
If you're not sure, the most useful thing to do is to visualize the graph...
import tensorflow as tf

# read graph definition
with tf.gfile.GFile('frozen.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# now build the graph in memory and visualize it
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="prefix")
    writer = tf.summary.FileWriter('out', graph)
    writer.close()
... with TensorBoard:
$ tensorboard --logdir out/
and see for yourself what the graph looks like and what the naming is.
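If TensorBoard isn't handy, you can also eyeball the naming by dumping node names straight from the graph_def parsed above; a quick sketch:

# print every node's type and name; frozen weights live in Const nodes
for node in graph_def.node:
    print(node.op, node.name)

# or just list the constants
print([n.name for n in graph_def.node if n.op == 'Const'])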
All you need is the magical library called tf.contrib.graph_editor. Now let's say you've stored the names of the previously trainable ops (they used to be variables, but now they're Const) in probable_variables (as in your Edit 2).
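If you'd rather not hand-pick the names, here is one way to guess them. This is just a sketch under the assumption that graph is the tf.Graph built above and that the weights are Const ops feeding Conv2D/MatMul/BiasAdd ops, so double-check the result against TensorBoard:

probable_variables = []
for op in graph.get_operations():
    if op.type not in ('Conv2D', 'MatMul', 'BiasAdd'):
        continue
    for inp in op.inputs:
        producer = inp.op
        # freezing sometimes leaves an Identity between the Const and its consumer
        if producer.type == 'Identity':
            producer = producer.inputs[0].op
        if producer.type == 'Const':
            probable_variables.append(producer.name)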
Note: remember the difference between ops, tensors, and variables. Ops are elements of the graph, a tensor is a buffer that contains the results of ops, and a variable is a wrapper around a tensor, made up of three ops: assign (to be called when you initialize the variable), read (called by other ops, e.g. conv2d), and the ref tensor (which holds the values).
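You can see those three pieces for yourself by creating a variable in a scratch graph and listing the ops it adds (a minimal sketch):

import tensorflow as tf

with tf.Graph().as_default() as g:
    v = tf.get_variable('demo', shape=[2, 2], dtype='float32')
    print([op.name for op in g.get_operations()])
    # prints something like ['demo/Initializer/...', 'demo', 'demo/Assign', 'demo/read']:
    # 'demo' is the VariableV2 op holding the ref tensor, 'demo/Assign' initializes it,
    # and 'demo/read' is what downstream ops such as conv2d actually consume

That '/read' suffix is exactly the naming the swap below relies on.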
Note 2: graph_editor can only be run outside a session; you cannot make any graph modifications online!
import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

# load the frozen graph (using a helper like load_pb at the bottom of this page)
graph = load_graph('frozen.pb')

# create a variable for each constant, beware the naming
const_var_name_pairs = []
created_vars = {}  # var_name -> tf.Variable, kept so we can load values later
with graph.as_default():
    for name in probable_variables:
        var_shape = graph.get_tensor_by_name('{}:0'.format(name)).get_shape()
        var_name = '{}_a'.format(name)
        created_vars[var_name] = tf.get_variable(name=var_name, shape=var_shape,
                                                 dtype='float32')
        const_var_name_pairs.append((name, var_name))

    # magic: now we swap the outputs of each Const with its variable's read op
    for const_name, var_name in const_var_name_pairs:
        const_op = graph.get_operation_by_name(const_name)
        var_reader_op = graph.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

# now we can safely create a session and copy the values
sess = tf.Session(graph=graph)
for const_name, var_name in const_var_name_pairs:
    ts = graph.get_tensor_by_name('{}:0'.format(const_name))
    created_vars[var_name].load(ts.eval(session=sess), session=sess)

# All done! Now you can make sure everything is correct by visualizing
# the graph and computing outputs for some inputs.
PS: this code was not tested; however, I've been using graph_editor and performing network surgery quite often lately, so I think it should mostly be correct :)
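Once the swap is done, the new variables are trainable by default, so you can put an optimizer on top and fine-tune. A minimal sketch, reusing graph and sess from above; 'prefix/input:0' and 'prefix/logits:0' are hypothetical tensor names, so substitute the real ones from your graph:

with graph.as_default():
    x = graph.get_tensor_by_name('prefix/input:0')        # hypothetical name
    logits = graph.get_tensor_by_name('prefix/logits:0')  # hypothetical name
    labels = tf.placeholder(tf.int64, shape=[None], name='labels')

    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))

    # remember which variables exist before the optimizer adds its slot variables
    existing = set(tf.global_variables())
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
    # initialize only the optimizer's new variables, so the weights
    # we just loaded are not overwritten
    sess.run(tf.variables_initializer(
        [v for v in tf.global_variables() if v not in existing]))

# then train as usual:
# sess.run(train_op, feed_dict={x: batch_x, labels: batch_y})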
I have verified @FalconUA's solution with tested code. Slight modifications were needed (notably, I use the initializer option in get_variable to properly initialize the Variables). Here it is!
Assuming your frozen model is stored in frozen_graph.pb:
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

probable_variables = [...] # kernels and biases of Conv2D and MatMul

tf_graph = load_pb('frozen_graph.pb')
const_var_name_pairs = []
with tf_graph.as_default() as g:
    for name in probable_variables:
        tensor = g.get_tensor_by_name('{}:0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,
                              initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variables' names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.
    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.
    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
Note: if you save the converted model and view it in TensorBoard or Netron, you will see that the Variables have taken the Constants' places. You will also see a bunch of dangling Constants, which you can optionally remove.
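For that optional cleanup and saving, one possibility (a sketch, not the only way) is tf.graph_util.extract_sub_graph, which keeps only the nodes reachable from the outputs you name; 'output_node' below is a placeholder for your real output op, and remember that variable values live in a checkpoint, not in the GraphDef:

from tensorflow.python.framework import graph_util

# drop the dangling Constants by keeping only what the output depends on
pruned_def = graph_util.extract_sub_graph(
    g.as_graph_def(), dest_nodes=['output_node'])  # replace with your output op name
tf.train.write_graph(pruned_def, '.', 'converted_graph.pbtxt')

# the structure is in the GraphDef; the values need a checkpoint
with g.as_default():
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './converted_model.ckpt')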
I have verified that the weight values are the same between the frozen and unfrozen versions.
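If you want to repeat that check yourself, one way is to evaluate each original constant next to its replacement variable and compare (a sketch, reusing const_var_name_pairs from above):

import numpy as np

with tf_graph.as_default() as g:
    with tf.Session() as sess:
        # the constant_initializer restores the frozen values into the variables
        sess.run(tf.global_variables_initializer())
        for const_name, var_name in const_var_name_pairs:
            const_val = sess.run(g.get_tensor_by_name('{}:0'.format(const_name)))
            var_val = sess.run(g.get_tensor_by_name('{}/read:0'.format(var_name)))
            assert np.allclose(const_val, var_val), const_name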
Here is the load_pb function:
import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph