How to load and use a saved model in TensorFlow?

I have found two ways to save a model in TensorFlow: tf.train.Saver() and SavedModelBuilder. However, I can't find any documentation on using the model after loading it the second way.

Note: I want to use the SavedModelBuilder approach because I train the model in Python and will serve it from another language (Go), and SavedModelBuilder seems to be the only option in that case.

This works well with tf.train.Saver() (the first way):

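```python
import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TF 2.x)

# Example graph; the variable values are illustrative.
x = tf.placeholder(tf.float32, name="x")
W = tf.Variable([0.3], dtype=tf.float32)
b = tf.Variable([-0.3], dtype=tf.float32)
model = tf.add(W * x, b, name="finalnode")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # save
    saver = tf.train.Saver()
    saver.save(sess, "/tmp/model")

    # load
    saver.restore(sess, "/tmp/model")

    # IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
    # I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.
    graph = tf.get_default_graph()
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7]}))
```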

tf.saved_model.builder.SavedModelBuilder() is described in the README, but after loading the model with tf.saved_model.loader.load(sess, [], export_dir), I can't find any documentation on how to get back at the nodes (see "finalnode" in the code above).

asked Aug 16 '17 at 04:08 by Thomas


2 Answers

What was missing was the signature (the signature_def_map passed when saving):

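```python
import tensorflow as tf  # TensorFlow 1.x API

export_dir = "/tmp/saved_model"  # SavedModelBuilder fails if this directory already exists

# Saving (x must be named "x" so it can be found as "x:0" after loading)
with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.float32, name="x")
    W = tf.Variable([0.3], dtype=tf.float32)
    b = tf.Variable([-0.3], dtype=tf.float32)
    model = tf.add(W * x, b, name="finalnode")
    sess.run(tf.global_variables_initializer())

    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map={
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={"x": x},
            outputs={"finalnode": model})
    })
    builder.save()

# Loading (the tag must match the one used when saving)
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))
```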
answered Sep 28 '22 at 11:09 by Thomas


Here's a code snippet that saves a model with simple_save and then loads it back to run predictions:

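```python
# Save the model (X_train, Y_train and isTraining are the graph's input tensors,
# predClass is its prediction tensor):
tf.saved_model.simple_save(sess, export_dir=saveModelPath,
                           inputs={"inputImageBatch": X_train,
                                   "inputClassBatch": Y_train,
                                   "isTrainingBool": isTraining},
                           outputs={"predictedClassBatch": predClass})
```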

Note that simple_save fills in certain defaults: it saves under the SERVING tag and stores the inputs/outputs under the default serving signature key (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py).
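To make those defaults concrete, here is a minimal sketch of roughly what simple_save does, written with SavedModelBuilder directly and based on the simple_save.py source linked above. It is not a drop-in replacement: it omits simple_save's asset collection and init-op handling. It uses tensorflow.compat.v1 so it should also run under TF 2.x, and the function names save_like_simple_save and roundtrip_example, as well as the tiny y = W*x + b model, are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.python.saved_model import signature_constants, tag_constants

tf.disable_eager_execution()  # graph mode, as in the TF 1.x code above


def save_like_simple_save(sess, export_dir, inputs, outputs):
    """Sketch of simple_save's defaults (omits its asset and init-op handling)."""
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        [tag_constants.SERVING],  # the default tag simple_save uses
        signature_def_map={
            # default signature key -> a "predict" signature built from the dicts
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.saved_model.signature_def_utils.predict_signature_def(inputs, outputs)
        },
        clear_devices=True)
    builder.save()


def roundtrip_example():
    """Hypothetical demo: save y = W*x + b, reload it, and predict for x = [1, 2, 3]."""
    export_dir = os.path.join(tempfile.mkdtemp(), "model")  # must not exist yet

    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.float32, shape=[None], name="x")
        W = tf.Variable(2.0)
        b = tf.Variable(1.0)
        y = tf.add(W * x, b, name="y")
        sess.run(tf.global_variables_initializer())
        save_like_simple_save(sess, export_dir, inputs={"x": x}, outputs={"y": y})

    with tf.Session(graph=tf.Graph()) as sess:
        meta = tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_dir)
        sig = meta.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        x = sess.graph.get_tensor_by_name(sig.inputs["x"].name)
        y = sess.graph.get_tensor_by_name(sig.outputs["y"].name)
        return sess.run(y, {x: [1.0, 2.0, 3.0]})
```

This is why the loading code uses tag_constants.SERVING and DEFAULT_SERVING_SIGNATURE_DEF_KEY: they are the values simple_save wrote.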

Now, to restore the model and use the inputs/outputs dicts:

```python
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

# Load into a fresh graph so names don't clash with anything already defined.
with tf.Session(graph=tf.Graph()) as sess:
    # Note: simple_save uses the SERVING tag by default.
    model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess,
                                       tags=[tag_constants.SERVING])
    # simple_save stores its inputs/outputs under the default serving signature key.
    signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    graph = tf.get_default_graph()

    # Map the keys used at save time back to the actual tensors in the graph.
    inputImage = graph.get_tensor_by_name(signature.inputs['inputImageBatch'].name)
    inputLabel = graph.get_tensor_by_name(signature.inputs['inputClassBatch'].name)  # not fed below
    isTraining = graph.get_tensor_by_name(signature.inputs['isTrainingBool'].name)
    outputPrediction = graph.get_tensor_by_name(signature.outputs['predictedClassBatch'].name)

    outPred = sess.run(outputPrediction, feed_dict={inputImage: sampleImages, isTraining: False})
    print("predicted classes:", outPred)
```

Note: looking up the default signature_def is what lets you reach the tensors through the keys specified in the inputs and outputs dicts, instead of hard-coding tensor names.
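The per-key lookups above can be generalized. The sketch below resolves every key of a signature at once; its defaults match what simple_save writes, and for the first answer's model you would pass tags=["tag"] and signature_key="model" instead. It uses tensorflow.compat.v1 so it should also run under TF 2.x; load_signature_tensors, simple_save_roundtrip and the tiny one-variable model are hypothetical names chosen for illustration.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.python.saved_model import signature_constants, tag_constants

tf.disable_eager_execution()  # graph mode, as in the TF 1.x code above


def load_signature_tensors(sess, export_dir,
                           tags=(tag_constants.SERVING,),
                           signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
    """Load a SavedModel into sess and map every signature key to its tensor.

    Returns (inputs, outputs): dicts from the keys used at save time
    (e.g. "inputImageBatch") to the corresponding tensors in sess.graph.
    """
    meta_graph = tf.saved_model.loader.load(sess, list(tags), export_dir)
    signature = meta_graph.signature_def[signature_key]
    inputs = {key: sess.graph.get_tensor_by_name(info.name)
              for key, info in signature.inputs.items()}
    outputs = {key: sess.graph.get_tensor_by_name(info.name)
               for key, info in signature.outputs.items()}
    return inputs, outputs


def simple_save_roundtrip():
    """Hypothetical demo: simple_save a tiny model, then predict via the helper."""
    export_dir = os.path.join(tempfile.mkdtemp(), "model")  # must not exist yet

    with tf.Session(graph=tf.Graph()) as sess:
        images = tf.placeholder(tf.float32, shape=[None], name="images")
        scale = tf.Variable(3.0)
        pred = tf.multiply(images, scale, name="pred")
        sess.run(tf.global_variables_initializer())
        tf.saved_model.simple_save(sess, export_dir,
                                   inputs={"inputImageBatch": images},
                                   outputs={"predictedClassBatch": pred})

    with tf.Session(graph=tf.Graph()) as sess:
        inputs, outputs = load_signature_tensors(sess, export_dir)
        return sess.run(outputs["predictedClassBatch"],
                        feed_dict={inputs["inputImageBatch"]: [1.0, 2.0]})
```

Keeping the lookup keyed by signature names means the serving code does not depend on how TensorFlow happened to name the tensors internally.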

answered Sep 28 '22 at 10:09 by Anurag