I have a trained Keras model that I would like to save to a protocol buffer (.pb) file. When I do so and load the model back, the predictions are wrong (they differ from the original model's) and the weights are wrong. Here is the model type:
type(model)
> keras.engine.training.Model
Here is the code I used to freeze and save it to a .pb file.
from keras import backend as K
K.set_learning_phase(0)
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
keras_session = K.get_session()
graph = keras_session.graph
graph.as_default()
keep_var_names=None
output_names=[out.op.name for out in model.outputs]
clear_devices=True
with graph.as_default():
    freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
    output_names = output_names or []
    output_names += [v.op.name for v in tf.global_variables()]
    input_graph_def = graph.as_graph_def()
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    frozen_graph = convert_variables_to_constants(keras_session, input_graph_def,
                                                  output_names, freeze_var_names)
tf.train.write_graph(frozen_graph, "model", "my_model.pb", as_text=False)
Then I read it like so:
pb_file = 'my_model.pb'
with tf.gfile.GFile(pb_file, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)
ops = graph.get_operations()
def get_outputs(feed_dict, output_tensor):
    with tf.Session() as sess:
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        output_tensor_loc = sess.graph.get_tensor_by_name(output_tensor)
        out = sess.run(output_tensor_loc, feed_dict=feed_dict)
        print("Shape is ", out.shape)
        return out
Then, when I compare the weights at the first convolutional layer, they have the same shape (and the shape looks correct) but the weights are different. All the weights are approximately 0:3 while in the original model at the same layer they are approximately -256:256.
get_outputs(feed_dict, 'conv1_relu/Relu:0')
Is there something wrong in the above code? Or is this whole approach wrong? I saw in a blog post someone using tf.train.Saver, which I'm not doing. Do I need to do that? If so, how can I do that to my keras.engine.training.Model?
Q: Is there something wrong in the above code? Or is this whole approach wrong?
A: The main problem is that tf.train.write_graph saves the TensorFlow graph, but not the weights of your model.
Q: Do I need to use tf.train.Saver? If so, how can I do that to my model?
A: Yes. In addition to saving the graph (which is only necessary if your subsequent scripts do not explicitly recreate it), you should use tf.train.Saver to save the weights of your model:
import tensorflow as tf
from keras import backend as K
# ... define your model in Keras and do some work
# Add ops to save and restore all the variables.
saver = tf.train.Saver()  # var_list=None (the default) saves all variables
# Get the TensorFlow session used by Keras
sess = K.get_session()
# Save the model's variables
save_path = saver.save(sess, "/tmp/model.ckpt")
Calling saver.save also saves a MetaGraphDef, which can then be used to restore the graph, so it is not necessary for you to use tf.train.write_graph. To restore the weights, simply use saver.restore:
with tf.Session() as sess:
    # restore variables from disk
    saver.restore(sess, "/tmp/model.ckpt")
The fact that you are using a Keras model does not change this approach as long as you use the TensorFlow backend (you still have a TensorFlow graph and weights). For more information about saving and restoring models in TensorFlow, please see the save and restore tutorial.
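Note that "/tmp/model.ckpt" is a prefix rather than a single file: assuming the default checkpoint format, saver.save typically writes model.ckpt.meta, model.ckpt.index and one or more model.ckpt.data-* files next to each other, and saver.restore is given the same prefix. A minimal sketch for checking which of those files are present before restoring; the helper name checkpoint_files is hypothetical and not part of TensorFlow, and the demo uses empty placeholder files:

```python
import glob
import os
import tempfile


def checkpoint_files(prefix):
    """Group the files found on disk for a checkpoint prefix by kind."""
    escaped = glob.escape(prefix)
    return {
        "meta": glob.glob(escaped + ".meta"),
        "index": glob.glob(escaped + ".index"),
        "data": sorted(glob.glob(escaped + ".data-*")),
    }


# Demo with empty placeholder files in a temporary directory
# (stand-ins for what saver.save would write).
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "model.ckpt")
    for suffix in (".meta", ".index", ".data-00000-of-00001"):
        open(prefix + suffix, "w").close()
    found = checkpoint_files(prefix)
    print({kind: len(paths) for kind, paths in found.items()})
```

If any of the three lists comes back empty for your own prefix, the restore is likely pointing at the wrong path.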
Alternative (neater) way to save a Keras model
Now, since you are using a Keras model, it is perhaps more convenient to save the model with model.save('model_path.h5') and restore it as follows:
from keras.models import load_model
# restore previously saved model
model = load_model('model_path.h5')
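Whichever way you save the model, it is worth checking that the restored weights match the original ones before comparing predictions. A minimal sketch, assuming both models expose their weights as lists of NumPy arrays (as Keras' model.get_weights() does); the helper name compare_weights and the tolerance are illustrative choices, and the arrays below are stand-ins rather than real model weights:

```python
import numpy as np


def compare_weights(original, restored, atol=1e-6):
    """Return (index, max abs difference) for every array pair that differs."""
    if len(original) != len(restored):
        raise ValueError("models have a different number of weight arrays")
    mismatches = []
    for i, (a, b) in enumerate(zip(original, restored)):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape:
            # Shapes differ, so no element-wise difference can be computed.
            mismatches.append((i, float("inf")))
        elif not np.allclose(a, b, atol=atol):
            mismatches.append((i, float(np.max(np.abs(a - b)))))
    return mismatches


# Stand-ins for model.get_weights() of the original and the reloaded model.
original = [np.ones((3, 3)), np.zeros(3)]
restored = [np.ones((3, 3)), np.zeros(3)]
print(compare_weights(original, restored))  # []
```

In practice you would pass model.get_weights() from the original model and from the model returned by load_model; an empty list means every weight array matched within the tolerance.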
UPDATE: Generating a single .pb file from the .ckpt files
If you want to generate a single .pb file, please use the former tf.train.Saver approach. Once you have generated the .ckpt files (.meta holds the graph and .data the weights), you can get the .pb file by calling Morgan's function freeze_graph as follows:
freeze_graph('/tmp', '<Comma separated output node names>')
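The second argument expects node (op) names, whereas a name such as 'conv1_relu/Relu:0' passed to get_tensor_by_name is a tensor name of the form <op name>:<output index>. A minimal sketch of building the comma-separated string from tensor names; the helper name output_node_names is hypothetical and is not part of TensorFlow or of freeze_graph:

```python
def output_node_names(tensor_names):
    """Turn tensor names like 'scope/op:0' into a comma-separated
    string of node (op) names like 'scope/op'."""
    node_names = []
    for name in tensor_names:
        # A tensor name is '<op name>:<output index>'; keep the op name only.
        op_name, sep, index = name.rpartition(":")
        if not (sep and index.isdigit()):
            op_name = name  # already a bare op name
        if op_name not in node_names:
            node_names.append(op_name)
    return ",".join(node_names)


print(output_node_names(["conv1_relu/Relu:0"]))  # conv1_relu/Relu
```

With a Keras model, the expression from the question, [out.op.name for out in model.outputs], yields the op names directly, so they only need to be joined with commas.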
References:
.pb file from the .ckpt files.