I have found two ways to save a model in TensorFlow: `tf.train.Saver()` and `SavedModelBuilder`. However, I can't find documentation on using the model after loading it the second way.
Note: I want to use the `SavedModelBuilder` way because I train the model in Python and will use it at serving time in another language (Go), and it seems that `SavedModelBuilder` is the only option in that case.
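This works great with `tf.train.Saver()` (the first way):

```python
import tensorflow as tf  # TensorFlow 1.x API

x = tf.placeholder(tf.float32, name="x")
W = tf.Variable([0.3], dtype=tf.float32)
b = tf.Variable([-0.3], dtype=tf.float32)
model = tf.add(W * x, b, name="finalnode")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")

# load
saver.restore(sess, "/tmp/model")

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.
graph = tf.get_default_graph()
model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})
```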
`tf.saved_model.builder.SavedModelBuilder()` is documented in the README, but after loading the model with `tf.saved_model.loader.load(sess, [], export_dir)`, I can't find documentation on getting back at the nodes (see `"finalnode"` in the code above).
What was missing was the signature:

```python
# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(
    sess, ["tag"],
    signature_def_map={
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={"x": x},
            outputs={"finalnode": model})
    })
builder.save()

# Loading
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))
```
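Instead of hard-coding tensor names such as `"x:0"`, the names can also be read back from the signature that was saved. The helper below is a minimal sketch, not part of TensorFlow's API: `signature_tensors` is a hypothetical name, and it assumes the TF 1.x objects involved here, i.e. the `MetaGraphDef` returned by `tf.saved_model.loader.load`, whose `signature_def` map holds `SignatureDef` protos with `inputs`/`outputs` maps of `TensorInfo` entries (each carrying a `.name`), plus a graph exposing `get_tensor_by_name`. It only touches those attributes, so it does not import TensorFlow itself.

```python
def signature_tensors(graph, meta_graph_def, signature_key):
    """Resolve a saved signature to live tensors in `graph`.

    Hypothetical helper (not a TensorFlow API). Returns two dicts,
    (inputs, outputs), mapping the logical keys used when saving
    (e.g. "x", "finalnode") to the tensors fetched by name from `graph`.
    """
    signature = meta_graph_def.signature_def[signature_key]
    # Each TensorInfo stores the full tensor name, e.g. "x:0".
    inputs = {key: graph.get_tensor_by_name(info.name)
              for key, info in signature.inputs.items()}
    outputs = {key: graph.get_tensor_by_name(info.name)
               for key, info in signature.outputs.items()}
    return inputs, outputs


# Example use with TF 1.x (assumes the "model" signature saved above):
#   with tf.Session(graph=tf.Graph()) as sess:
#       meta = tf.saved_model.loader.load(sess, ["tag"], export_dir)
#       inputs, outputs = signature_tensors(sess.graph, meta, "model")
#       print(sess.run(outputs["finalnode"], {inputs["x"]: [5, 6, 7, 8]}))
```

Looking tensors up through the signature keeps the loading code working even if the graph-level names change (for example, if the placeholder is created without `name="x"`), since only the logical keys are shared between the saving and loading code.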
Here's a code snippet to save a model with `simple_save` and then restore it and run predictions:

```python
# Save the model:
tf.saved_model.simple_save(sess,
                           export_dir=saveModelPath,
                           inputs={"inputImageBatch": X_train,
                                   "inputClassBatch": Y_train,
                                   "isTrainingBool": isTraining},
                           outputs={"predictedClassBatch": predClass})
```
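Note that `simple_save` fixes certain defaults: it saves the graph under the `tag_constants.SERVING` tag and registers the signature under `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY` (this can be seen in the source at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py).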
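A bare `KeyError` when reading `signature_def[...]` usually means the signature was saved under a different key than the one requested: `simple_save` uses `"serving_default"` (the value of `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY` in TF 1.x), while the `SavedModelBuilder` example above used `"model"`. The sketch below assumes the same `MetaGraphDef` object returned by `tf.saved_model.loader.load`; `get_signature` is a hypothetical helper, not a TensorFlow API, that reports which keys do exist.

```python
# Value of signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY in TF 1.x.
DEFAULT_SERVING_KEY = "serving_default"


def get_signature(meta_graph_def, key=DEFAULT_SERVING_KEY):
    """Return the SignatureDef stored under `key`.

    Hypothetical helper (not a TensorFlow API): raises a KeyError that
    lists the available signature keys instead of a bare KeyError.
    """
    signatures = meta_graph_def.signature_def
    if key not in signatures:
        raise KeyError("signature %r not found; available: %s"
                       % (key, sorted(signatures)))
    return signatures[key]
```

From the command line, `saved_model_cli show --dir <export_dir> --all` prints the tags, signature keys, and input/output tensor names of a SavedModel, which is a quick way to check what was actually saved.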
Now, to restore the model and use the inputs/outputs dicts:

```python
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

with tf.Session() as sess:
    # Note: simple_save puts the model under the SERVING tag by default.
    model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess,
                                       tags=[tag_constants.SERVING])
    graph = tf.get_default_graph()
    signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    inputImage_name = signature.inputs['inputImageBatch'].name
    inputImage = graph.get_tensor_by_name(inputImage_name)

    inputLabel_name = signature.inputs['inputClassBatch'].name
    inputLabel = graph.get_tensor_by_name(inputLabel_name)

    isTraining_name = signature.inputs['isTrainingBool'].name
    isTraining = graph.get_tensor_by_name(isTraining_name)

    outputPrediction_name = signature.outputs['predictedClassBatch'].name
    outputPrediction = graph.get_tensor_by_name(outputPrediction_name)

    outPred = sess.run(outputPrediction,
                       feed_dict={inputImage: sampleImages, isTraining: False})
    print("predicted classes:", outPred)
```
Note: the default signature_def key was needed to look up the tensor names that correspond to the keys of the `inputs` and `outputs` dicts passed to `simple_save`.
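In the snippet above only `inputImageBatch` and `isTrainingBool` are fed. In TF 1.x, `sess.run` only requires values for placeholders that the fetched tensor actually depends on, so `inputClassBatch` can presumably be left out because the prediction does not use the labels. The sketch below makes that explicit by building the `feed_dict` from the logical names used in `simple_save`; `make_feed_dict` is a hypothetical helper, not a TensorFlow API, and it assumes a plain dict mapping those names to the tensors looked up above.

```python
def make_feed_dict(input_tensors, values):
    """Build a feed_dict from logical signature input names.

    Hypothetical helper (not a TensorFlow API). `input_tensors` maps
    signature input keys (e.g. "inputImageBatch") to tensors, and
    `values` maps a subset of those keys to the data to feed. Inputs
    missing from `values` are simply not fed.
    """
    unknown = sorted(set(values) - set(input_tensors))
    if unknown:
        # Catch typos in key names instead of silently ignoring them.
        raise KeyError("unknown signature inputs: %s (available: %s)"
                       % (unknown, sorted(input_tensors)))
    return {input_tensors[key]: value for key, value in values.items()}


# Example use inside the session above:
#   inputs = {"inputImageBatch": inputImage, "inputClassBatch": inputLabel,
#             "isTrainingBool": isTraining}
#   feed = make_feed_dict(inputs, {"inputImageBatch": sampleImages,
#                                  "isTrainingBool": False})
#   outPred = sess.run(outputPrediction, feed_dict=feed)
```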