
Saving tf.trainable_variables() using convert_variables_to_constants

I have a Keras model that I would like to convert to a Tensorflow protobuf (e.g. saved_model.pb).

This model comes from transfer learning on the vgg-19 network: the head was cut off and replaced with fully-connected + softmax layers, which were trained while the rest of the vgg-19 network was frozen.

I can load the model in Keras, and then use keras.backend.get_session() to run the model in tensorflow, generating the correct predictions:

import cv2
import keras
import tensorflow as tf

frame = preprocess(cv2.imread("path/to/img.jpg"))
keras_model = keras.models.load_model("path/to/keras/model.h5")

keras_prediction = keras_model.predict(frame)

print(keras_prediction)

with keras.backend.get_session() as sess:

    tvars = tf.trainable_variables()

    output = sess.graph.get_tensor_by_name('Softmax:0')
    input_tensor = sess.graph.get_tensor_by_name('input_1:0')

    tf_prediction = sess.run(output, {input_tensor: frame})
    print(tf_prediction) # this matches keras_prediction exactly

If I don't include the line tvars = tf.trainable_variables(), then the tf_prediction variable is completely wrong and doesn't match the output from keras_prediction at all. In fact all the values in the output (single array with 4 probability values) are exactly the same (~0.25, all adding to 1). This made me suspect that weights for the head are just initialized to 0 if tf.trainable_variables() is not called first, which was confirmed after inspecting the model variables. In any case, calling tf.trainable_variables() causes the tensorflow prediction to be correct.
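As a rough sanity check on that suspicion, here is a minimal sketch in plain Python (not the actual model: the feature values, the layer sizes, and the helper names softmax and dense are made up for illustration). It shows that a fully-connected + softmax head whose weights and biases are all zero produces the same uniform output for any input, which is consistent with the ~0.25 values described above:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dense(features, weights, biases):
    """Fully connected layer: one logit per output unit."""
    return [
        sum(f * w for f, w in zip(features, unit_weights)) + b
        for unit_weights, b in zip(weights, biases)
    ]

# Hypothetical activations coming out of the frozen base network.
features = [0.3, -1.2, 2.5]

# A 4-class head whose weights and biases are all zero.
zero_weights = [[0.0] * len(features) for _ in range(4)]
zero_biases = [0.0] * 4

probs = softmax(dense(features, zero_weights, zero_biases))
print(probs)  # [0.25, 0.25, 0.25, 0.25]
```

Every logit is zero regardless of the input, so the softmax spreads the probability evenly over the 4 classes.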

The problem is that when I try to save this model, the variables from tf.trainable_variables() don't actually get saved to the .pb file:

from tensorflow.python.framework import graph_util, graph_io

with keras.backend.get_session() as sess:
    tvars = tf.trainable_variables()

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['Softmax'])
    graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)

What I am asking is: how can I save a Keras model as a Tensorflow protobuf with the tf.trainable_variables() intact?

Thanks so much!

user2205763 asked Aug 20 '17 06:08



1 Answer

Your approach of freezing the variables in the graph (converting them to constants) should work, but it isn't necessary and is trickier than the other approaches (more on this below). If you want graph freezing for some reason (e.g. exporting to a mobile device), I'd need more details to help debug, as I'm not sure what Keras is doing implicitly behind the scenes with your graph. However, if you just want to save a graph and load it later, I can explain how to do that (though there's no guarantee that whatever Keras is doing won't screw it up; happy to help debug that).

So there are actually two formats at play here. One is the GraphDef, which is used for Checkpointing, as it does not contain metadata about inputs and outputs. The other is a MetaGraphDef which contains metadata and a graph def, the metadata being useful for prediction and running a ModelServer (from tensorflow/serving).

In either case you need to do more than just call graph_io.write_graph because the variables are usually stored outside the graphdef.

There are wrapper libraries for both these use cases. tf.train.Saver is primarily used for saving and restoring checkpoints.

However, since you want prediction, I would suggest using a tf.saved_model.builder.SavedModelBuilder to build a SavedModel binary. I've provided some boilerplate for this below:

import keras
import tensorflow as tf
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as DEFAULT_SIG_DEF

builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
  output = sess.graph.get_tensor_by_name('Softmax:0')
  input_tensor = sess.graph.get_tensor_by_name('input_1:0')
  sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
    {'input': input_tensor},
    {'output': output}
  )
  builder.add_meta_graph_and_variables(
      sess, tf.saved_model.tag_constants.SERVING,
      signature_def_map={
        DEFAULT_SIG_DEF: sig_def
      }
  )
builder.save()

After running this code you should have a mymodel/saved_model.pb file as well as a directory mymodel/variables/ containing the checkpoint files (variables.index and variables.data-*) that hold the variable values.
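If you want to confirm the export produced that layout before trying to load it, a small standard-library check along these lines may help. This is only a sketch: check_saved_model_dir is a hypothetical helper name, not part of TensorFlow, and it only looks at the file layout described above, not at the contents of the files.

```python
import os

def check_saved_model_dir(export_dir):
    """Return a list of problems found in a SavedModel export directory.

    Hypothetical helper: it only checks that saved_model.pb and a
    non-empty variables/ directory exist under export_dir.
    """
    problems = []
    if not os.path.isfile(os.path.join(export_dir, "saved_model.pb")):
        problems.append("missing saved_model.pb")
    variables_dir = os.path.join(export_dir, "variables")
    if not os.path.isdir(variables_dir):
        problems.append("missing variables/ directory")
    elif not os.listdir(variables_dir):
        problems.append("variables/ directory is empty")
    return problems

# An empty list means the expected layout is present.
print(check_saved_model_dir("./mymodel"))
```

An empty variables/ directory would be a hint that the variable values were not written along with the graph, which is the symptom described in the question.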

Then to load the model again, simply use tf.saved_model.loader:

# Does Keras give you the ability to start with a fresh graph?
# If not you'll need to do this in a separate program to avoid
# conflicts with the old default graph
with tf.Session(graph=tf.Graph()) as sess:
  meta_graph_def = tf.saved_model.loader.load(
      sess, 
      tf.saved_model.tag_constants.SERVING,
      './mymodel'
  )
  # From this point variables and graph structure are restored

  sig_def = meta_graph_def.signature_def[DEFAULT_SIG_DEF]
  print(sess.run(sig_def.outputs['output'], feed_dict={sig_def.inputs['input']: frame}))

Obviously, more efficient prediction is available through tensorflow/serving or Cloud ML Engine, but this should work. It's possible that Keras is doing something under the hood that will interfere with this process as well, and if so we'd like to hear about it. I'd also like to make sure that Keras users are able to freeze graphs, so if you want to send me a gist with your full code, maybe I can find someone who knows Keras well to help me debug.

EDIT: You can find an end to end example of this here: https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/keras/trainer/model.py#L85

Eli Bixby answered Nov 16 '22 01:11