
How to properly serve an object detection model from Tensorflow Object Detection API?

I am using the Tensorflow Object Detection API (github.com/tensorflow/models/tree/master/object_detection) for an object detection task. Right now I am having trouble serving the detection model I trained with Tensorflow Serving (tensorflow.github.io/serving/).

1. The first issue I am encountering is exporting the model to servable files. The object detection API kindly includes an export script, so I was able to convert ckpt files to pb files with a variables folder. However, the output has no content in the 'variables' folder. I thought this was a bug and reported it on Github, but it seems they intended to convert variables to constants so that there would be no variables. The details can be found HERE.
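
If I understand that thread correctly, the export roughly follows the standard freezing pattern sketched below, which is why the 'variables' folder ends up empty. This is just an illustration, not the exporter's actual code; the checkpoint path is from my run and the output node names are placeholders for whatever the graph's output nodes are actually called:

    # Sketch of graph freezing: checkpoint values are folded into the graph as
    # Const nodes, so a SavedModel built from the result has nothing to put in
    # its variables/ folder.
    import tensorflow as tf

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('resnet_ckpt/model.ckpt-17586.meta')
        saver.restore(sess, 'resnet_ckpt/model.ckpt-17586')
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            # Placeholder names; use the graph's real output node names.
            output_node_names=['detection_boxes', 'detection_scores',
                               'detection_classes', 'num_detections'])
        # Every tf.Variable feeding those outputs is now a constant in
        # frozen_graph_def.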

The flags I used when exporting the saved model are as follows:

    CUDA_VISIBLE_DEVICES=0 python export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path configs/rfcn_resnet50_car_Jul_20.config \
        --checkpoint_path resnet_ckpt/model.ckpt-17586 \
        --inference_graph_path serving_model/1 \
        --export_as_saved_model True

The exported model runs perfectly fine in Python when I switch --export_as_saved_model to False.
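
(By "runs perfectly fine in Python" I mean loading the frozen graph with a plain tf.Session along the lines below; the graph file path, image file, and tensor names are assumptions based on my export flags and the Object Detection API defaults:)

    # Sketch of plain-Python inference on the frozen graph produced with
    # --export_as_saved_model False. Path and tensor names are placeholders.
    import numpy as np
    import tensorflow as tf
    from PIL import Image

    graph_def = tf.GraphDef()
    with tf.gfile.GFile('frozen_inference_graph.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        with tf.Session(graph=graph) as sess:
            image_np = np.expand_dims(np.array(Image.open('test.jpg')), axis=0)
            boxes, scores, classes, num = sess.run(
                ['detection_boxes:0', 'detection_scores:0',
                 'detection_classes:0', 'num_detections:0'],
                feed_dict={'image_tensor:0': image_np})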

But I am still having trouble serving the model.

When I was trying to run:

~/serving$ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=gan --model_base_path=<my_model_path>

I got:

2017-07-27 16:11:53.222439: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:155] Restoring SavedModel bundle.
2017-07-27 16:11:53.222497: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:165] The specified SavedModel has no variables; no checkpoints were restored.
2017-07-27 16:11:53.222502: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running LegacyInitOp on SavedModel bundle.
2017-07-27 16:11:53.229463: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: success. Took 281805 microseconds.
2017-07-27 16:11:53.229508: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: gan version: 1}
2017-07-27 16:11:53.244716: I tensorflow_serving/model_servers/main.cc:290] Running ModelServer at 0.0.0.0:9000 ...

I think the model was not properly loaded since it shows "The specified SavedModel has no variables; no checkpoints were restored."

But since all the variables have been converted into constants, that message seems reasonable. I am not sure here.
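
One thing I can do to see exactly what the model server is exposing is to load the exported SavedModel back in Python and print its signatures. This is just a small sanity-check sketch, assuming the export directory serving_model/1 from above:

    # Sanity check: list the signatures in the exported SavedModel, together
    # with the input/output tensor names each signature maps to.
    import tensorflow as tf

    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], 'serving_model/1')
        for sig_name, sig in meta_graph_def.signature_def.items():
            print(sig_name)
            print('  inputs: ', {k: v.name for k, v in sig.inputs.items()})
            print('  outputs:', {k: v.name for k, v in sig.outputs.items()})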

2. I was not able to use the client to call the server and run detection on a sample image.

The client script is listed below:

from __future__ import print_function
from __future__ import absolute_import

# Communication to TensorFlow server via gRPC
from grpc.beta import implementations
import tensorflow as tf
import numpy as np
from PIL import Image
# TensorFlow serving stuff to send messages
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2


# Command line arguments
tf.app.flags.DEFINE_string('server', 'localhost:9000',
                           'PredictionService host:port')
tf.app.flags.DEFINE_string('image', '', 'path to image in JPEG format')
FLAGS = tf.app.flags.FLAGS


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)

def main(_):
    host, port = FLAGS.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # Send request
    request = predict_pb2.PredictRequest()
    image = Image.open(FLAGS.image)
    image_np = load_image_into_numpy_array(image)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    # Call GAN model to make prediction on the image
    request.model_spec.name = 'gan'
    request.model_spec.signature_name = 'predict_images'
    request.inputs['inputs'].CopyFrom(
        tf.contrib.util.make_tensor_proto(image_np_expanded))

    result = stub.Predict(request, 60.0)  # 60 secs timeout
    print(result)


if __name__ == '__main__':
    tf.app.run()

To match request.model_spec.signature_name = 'predict_images', I modified the exporter.py script in the object detection API (github.com/tensorflow/models/blob/master/object_detection/exporter.py), starting at line 289, from:

    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            detection_signature,
    },

To:

    signature_def_map={
        'predict_images': detection_signature,
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            detection_signature,
    },

I did this because I had no idea how to call the default signature key.
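
(For what it's worth, the default key is just the string 'serving_default', exposed as a constant, so I believe the unmodified exporter could also be addressed from the client script above like this:)

    # Addressing the default signature without modifying exporter.py; the
    # constant below equals the string 'serving_default'.
    from tensorflow.python.saved_model import signature_constants

    request.model_spec.signature_name = (
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)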

When I run the following command:

bazel-bin/tensorflow_serving/example/client --server=localhost:9000 --image=<my_image_file>

I got the following error message:

Traceback (most recent call last):
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 54, in <module>
    tf.app.run()
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 49, in main
    result = stub.Predict(request, 60.0)  # 60 secs timeout
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0")

I am not quite sure what's going on here.

Initially I thought my client script might be incorrect, but then I found that the AbortionError comes from github.com/tensorflow/tensorflow/blob/f488419cd6d9256b25ba25cbe736097dfeee79f9/tensorflow/core/graph/subgraph.cc. It seems I get this error while the graph is being built, so it might be caused by the first issue above.
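
A check I can run to narrow this down is to load the exported SavedModel and look for the node the server is trying to feed ('ToFloat'), as well as the graph's actual placeholders. Again a sketch, assuming the serving_model/1 export directory from above:

    # Does the serving graph actually contain the 'ToFloat' node the server
    # wants to feed, and what are its real feedable placeholders?
    import tensorflow as tf

    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], 'serving_model/1')
        for op in sess.graph.get_operations():
            if op.type == 'Placeholder' or 'ToFloat' in op.name:
                print(op.name, op.type)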

I am new to this stuff, so I am really confused; I think I may have gone wrong right at the start. Is there any way to properly export and serve the detection model? Any suggestions would be of great help!

Xinyao Wang

2 Answers

The current exporter code doesn't populate the signature field properly, so serving with the model server doesn't work. Apologies for that. A new version with better support for exporting the model is coming. It includes some important fixes and improvements needed for serving, especially serving on Cloud ML Engine. See the github issue if you want to try an early version of it.

For "The specified SavedModel has no variables; no checkpoints were restored." message, it is expected due to the exact reason you said, as all variables are converted into constants in the graph. For the error of "FeedInputs: unable to find feed output ToFloat:0", make sure you use TF 1.2 when building the model server.

yxshi


  1. Your idea is fine. It's OK to have that warning.

  2. The issue is that the input needs to be converted to uint8, as the model expects. Here is the code snippet that worked for me:

from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow_serving.apis import predict_pb2

request = predict_pb2.PredictRequest()
request.model_spec.name = 'gan'
# Use the default signature key ('serving_default') rather than a custom one.
request.model_spec.signature_name = (
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

image = Image.open('any.jpg')
image_np = load_image_into_numpy_array(image)  # helper from the question's client script
image_np_expanded = np.expand_dims(image_np, axis=0)

# Passing shape and dtype explicitly makes the tensor proto uint8,
# which is what the exported detection graph expects.
request.inputs['inputs'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image_np_expanded,
                                      shape=image_np_expanded.shape,
                                      dtype='uint8'))

The important part for you is shape=image_np_expanded.shape, dtype='uint8'. Also, make sure to pull the latest update of serving.

Sumsuddin Shojib