
Convert graph (pb) to SavedModel for gcloud ml-engine predict

I trained an object detector using Cloud Machine Learning Engine according to the recent post by Google’s Derek Chow on the Google Cloud Big Data And Machine Learning Blog and now want to predict using Cloud Machine Learning Engine.

The instructions include code to export the TensorFlow graph as output_inference_graph.pb, but not how to convert the protobuf (pb) file into the SavedModel format required by gcloud ml-engine predict.

I reviewed the answer by Google’s @rhaertel80 for how to convert a “Tensorflow For Poets” image classification model and the answer provided by Google’s @MarkMcDonald for how to convert a “Tensorflow For Poets 2” image classification model but neither appears to work for the object detector graph (pb) described in the blog post.

How does one convert that object detector graph (pb) so it can be used for gcloud ml-engine predict, please?
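For reference, the prediction call I'm targeting looks roughly like this (the model and version names here are placeholders I made up, not from the tutorial):

gcloud ml-engine predict \
    --model my_object_detector \
    --version v1 \
    --json-instances instances.json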

Chuck Finley asked Oct 17 '22


2 Answers

A SavedModel contains a MetaGraphDef inside its structure. To create a SavedModel from a GraphDef in Python, you can use tf.saved_model.builder.SavedModelBuilder, following the usage pattern shown below (from the TensorFlow SavedModel documentation).

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# The first MetaGraphDef is saved together with the variables.
with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph_and_variables(sess,
                                       [tag_constants.TRAINING],
                                       signature_def_map=foo_signatures,
                                       assets_collection=foo_assets)
...
# Additional MetaGraphDefs share the variables saved above.
with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph(["bar-tag", "baz-tag"])
...
# Writes the SavedModel to export_dir.
builder.save()
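Applied to a frozen object detection graph such as output_inference_graph.pb, that pattern looks roughly like the sketch below. The tensor names (image_tensor, detection_boxes, detection_scores, detection_classes, num_detections) are the ones the Object Detection API typically exports, so verify them against your own graph before relying on this:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants, tag_constants

def convert_pb_to_saved_model(frozen_graph_path, export_dir):
  # Load the frozen GraphDef exported by the object detection tutorial.
  graph_def = tf.GraphDef()
  with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
    graph_def.ParseFromString(f.read())

  with tf.Session(graph=tf.Graph()) as sess:
    tf.import_graph_def(graph_def, name='')

    # Tensor names below are assumptions based on a typical Object Detection
    # API export; check them with e.g. sess.graph.get_operations().
    inputs = {'inputs': tf.saved_model.utils.build_tensor_info(
        sess.graph.get_tensor_by_name('image_tensor:0'))}
    outputs = {
        name: tf.saved_model.utils.build_tensor_info(
            sess.graph.get_tensor_by_name(name + ':0'))
        for name in ('detection_boxes', 'detection_scores',
                     'detection_classes', 'num_detections')}

    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=signature_constants.PREDICT_METHOD_NAME)

    # Write the SavedModel with a serving tag and default signature.
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
    builder.save()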
Satoshi Kataoka answered Oct 21 '22


This post saved me! Hoping to help people who come here: I exported the model successfully using the method from https://stackoverflow.com/a/48102615/6124383 and the commit below.

https://github.com/tensorflow/tensorflow/pull/15855/commits/81ec5d20935352d71ff56fac06c36d6ff0a7ae05

def export_model(sess, architecture, saved_model_dir):
  # Pick the graph's input tensor name based on the retrained architecture.
  if architecture == 'inception_v3':
    input_tensor = 'DecodeJpeg/contents:0'
  elif architecture.startswith('mobilenet_'):
    input_tensor = 'input:0'
  else:
    raise ValueError('Unknown architecture', architecture)

  in_image = sess.graph.get_tensor_by_name(input_tensor)
  inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}

  out_classes = sess.graph.get_tensor_by_name('final_result:0')
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}

  # Build a prediction signature mapping the input image to the class scores.
  signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
  )

  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

  # Save out the SavedModel.
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    },
    legacy_init_op=legacy_init_op)
  builder.save()

# Call this at the end of def main(_):
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)

# Also add a flag for the export directory:
parser.add_argument(
    '--saved_model_dir',
    type=str,
    default='/tmp/saved_models/1/',
    help='Where to save the exported graph.'
)
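
After the export, a quick local sanity check of the SavedModel can confirm the signature before deploying; this snippet is just an illustration and reuses the default path from --saved_model_dir above:

import tensorflow as tf

# Load the exported SavedModel back and print its serving signature.
with tf.Session(graph=tf.Graph()) as sess:
  meta_graph = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], '/tmp/saved_models/1/')
  signature = meta_graph.signature_def[
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(signature.inputs)   # should list the 'image' input
  print(signature.outputs)  # should list the 'prediction' output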
sgffsg answered Oct 21 '22