
How do I get a TensorFlow/Keras model that takes images as input to serve predictions on Cloud ML Engine?

Multiple existing questions try to address how to handle image data when serving predictions for TensorFlow/Keras models on Cloud ML Engine.

Unfortunately, some of the answers are out-of-date and none of them comprehensively addresses the problem. The purpose of this post is to provide a comprehensive, up-to-date answer for future reference.

rhaertel80 asked Jul 19 '18 22:07



1 Answer

This answer is going to focus on Estimators, which are high-level APIs for writing TensorFlow code and currently the recommended way. In addition, Keras uses Estimators to export models for serving.

This answer is going to be divided into two parts:

  1. How to write the input_fn.
  2. Client code for sending requests once the model is deployed.

How to Write the input_fn

The exact details of your input_fn will depend on your requirements. For instance, you may do image decoding and resizing client side, you might use JPG vs. PNG, you may expect a specific image size, or you may have additional inputs besides images. We will focus on a fairly general approach that accepts various image formats at a variety of sizes, so the following generic code should be fairly easy to adapt to any of the more specific scenarios.

import tensorflow as tf  # Written against the TF 1.x API (tf.placeholder, tf.estimator).

HEIGHT = 199
WIDTH = 199
CHANNELS = 1

def serving_input_receiver_fn():

  def decode_and_resize(image_str_tensor):
    """Decodes jpeg string, resizes it and returns a uint8 tensor."""
    image = tf.image.decode_jpeg(image_str_tensor, channels=CHANNELS)
    # resize_bilinear expects a batch dimension, so add one and remove it afterwards.
    image = tf.expand_dims(image, 0)
    image = tf.image.resize_bilinear(
        image, [HEIGHT, WIDTH], align_corners=False)
    image = tf.squeeze(image, squeeze_dims=[0])
    image = tf.cast(image, dtype=tf.uint8)
    return image

  # Optional; currently necessary for batch prediction.
  # Note: these key tensors are not wired into the ServingInputReceiver below.
  key_input = tf.placeholder(tf.string, shape=[None])
  key_output = tf.identity(key_input)

  # Each element of input_ph holds the raw bytes of one encoded image.
  input_ph = tf.placeholder(tf.string, shape=[None], name='image_binary')
  images_tensor = tf.map_fn(
      decode_and_resize, input_ph, back_prop=False, dtype=tf.uint8)
  # Convert uint8 pixel values to float32.
  images_tensor = tf.image.convert_image_dtype(images_tensor, dtype=tf.float32)

  return tf.estimator.export.ServingInputReceiver(
      {'images': images_tensor},
      {'bytes': input_ph})

If you've saved a Keras model and would like to convert it to a SavedModel, use the following:

KERAS_MODEL_PATH='/path/to/model'
MODEL_DIR='/path/to/store/checkpoints'
EXPORT_PATH='/path/to/store/savedmodel'

# If you are invoking this from your training code, use `keras_model=model` instead.
estimator = tf.keras.estimator.model_to_estimator(
    keras_model_path=KERAS_MODEL_PATH,
    model_dir=MODEL_DIR)
estimator.export_savedmodel(
    EXPORT_PATH,
    serving_input_receiver_fn=serving_input_receiver_fn) 

Sending Requests (Client Code)

The body of the requests sent to the service will look like the following:

{
  "instances": [
    {"bytes": {"b64": "<base64 encoded image>"}},  # image 1
    {"bytes": {"b64": "<base64 encoded image>"}}   # image 2 ...        
  ]
}
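
As a minimal sketch, a body of this shape could be built in Python from raw image bytes as follows. The helper name build_request_body is hypothetical (it is not part of any library), and the 'bytes' key is assumed to match the input alias used in serving_input_receiver_fn above.

```python
import base64
import json


def build_request_body(images):
    """Builds the predict request body from an iterable of raw image bytes.

    Hypothetical helper for illustration only; 'bytes' is assumed to match
    the input alias declared in serving_input_receiver_fn.
    """
    instances = []
    for img_data in images:
        # b64encode returns bytes in Python 3; decode so it is JSON-serializable.
        encoded = base64.b64encode(img_data).decode('utf-8')
        instances.append({'bytes': {'b64': encoded}})
    return {'instances': instances}


if __name__ == '__main__':
    # Placeholder bytes stand in for real image file contents.
    body = build_request_body([b'fake-image-1', b'fake-image-2'])
    print(json.dumps(body, indent=2))
```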

You can test your model and requests locally before deploying, which speeds up debugging. For this, we'll use gcloud ml-engine local predict. Before we do that, note that gcloud's input format is a slight transformation of the request body shown above: gcloud treats each line of the input file as one instance (image) and constructs the JSON request from those lines. So instead of the request above, the file contains:

{"bytes": {"b64": "<base64 encoded image>"}}
{"bytes": {"b64": "<base64 encoded image>"}}

gcloud will transform this file into the request above. Here is some example Python code that can produce a file suitable for use with gcloud:

import base64
import sys

for filename in sys.argv[1:]:
  with open(filename, 'rb') as f:
    img_data = f.read()
    # b64encode returns bytes in Python 3; decode to avoid printing a b'...' literal.
    print('{"bytes": {"b64": "%s"}}' % (base64.b64encode(img_data).decode('utf-8'),))

(Let's call this file to_instances.py)
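
To make the line-per-instance transformation concrete, here is a rough sketch of the wrapping step. It only mimics the behavior described above and is not gcloud's actual implementation; lines_to_request_body is a hypothetical name.

```python
import json


def lines_to_request_body(lines):
    """Wraps newline-delimited JSON instances in an {"instances": [...]} body.

    Illustrative only: mimics the transformation described in the text,
    not gcloud's real code. Blank lines are skipped.
    """
    instances = [json.loads(line) for line in lines if line.strip()]
    return {'instances': instances}


if __name__ == '__main__':
    # Two sample lines in the same format that to_instances.py emits.
    sample = [
        '{"bytes": {"b64": "aGVsbG8="}}',
        '{"bytes": {"b64": "d29ybGQ="}}',
    ]
    print(json.dumps(lines_to_request_body(sample)))
```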

To test the model with predictions:

python to_instances.py img1.jpg img2.jpg > instances.json
gcloud ml-engine local predict --model-dir /path/to/model --json-instances=instances.json

After we've finished debugging, we can deploy the model to the cloud using gcloud ml-engine models create and gcloud ml-engine versions create as described in the documentation.

At this point, you can use your desired client to send requests to your model on the service. Note that this will require an authentication token. We'll examine a few examples in various languages. In each case, we'll assume your model is called my_model.

gcloud

This is nearly identical to local predict:

python to_instances.py img1.jpg img2.jpg > instances.json
gcloud ml-engine predict --model my_model --json-instances=instances.json    

curl

We'll need a script like to_instances.py to convert images; let's call it to_payload.py:

import base64
import json 
import sys

instances = []
for filename in sys.argv[1:]:
  with open(filename, 'rb') as f:
    img_data = f.read()
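    # Wrap each image as {"bytes": {"b64": ...}} so it matches the request body shown earlier.
    instances.append({"bytes": {"b64": base64.b64encode(img_data).decode("utf-8")}})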
print(json.dumps({"instances": instances}))

python to_payload.py img1.jpg img2.jpg > payload.json

curl -m 180 -X POST -v -k -H "Content-Type: application/json" \
    -d @payload.json \
    -H "Authorization: Bearer `gcloud auth print-access-token`" \
    https://ml.googleapis.com/v1/projects/${YOUR_PROJECT}/models/my_model:predict
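
For readers who prefer to stay in Python without the client library, the same request can be sketched with the standard library. This is only an illustration under assumptions: build_predict_request is a hypothetical helper, and the access token is assumed to be obtained elsewhere (for example, from the output of gcloud auth print-access-token).

```python
import json
import urllib.request


def build_predict_request(project, model, payload, access_token):
    """Builds (but does not send) a POST request for the predict endpoint.

    Hypothetical helper; access_token is assumed to be supplied by the caller.
    """
    url = 'https://ml.googleapis.com/v1/projects/{}/models/{}:predict'.format(
        project, model)
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode('utf-8'),
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(access_token),
        },
        method='POST')


if __name__ == '__main__':
    req = build_predict_request(
        'my_project', 'my_model',
        {'instances': [{'bytes': {'b64': 'aGVsbG8='}}]},
        'TOKEN_PLACEHOLDER')
    print(req.get_method(), req.full_url)
    # To actually send it: urllib.request.urlopen(req, timeout=180)
```

Building the request separately from sending it makes the URL and headers easy to inspect before any network call is made.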

Python

import base64

import googleapiclient.discovery

PROJECT = "my_project"
MODEL = "my_model"

img_data = ... # your client will have its own way to get image data.

# Create the ML Engine service object.
# To authenticate set the environment variable
# GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
service = googleapiclient.discovery.build('ml', 'v1')
name = 'projects/{}/models/{}'.format(PROJECT, MODEL)

# b64encode returns bytes; decode to str so the body is JSON-serializable.
response = service.projects().predict(
    name=name,
    body={'instances': [{'b64': base64.b64encode(img_data).decode('utf-8')}]}
).execute()

if 'error' in response:
    raise RuntimeError(response['error'])

predictions = response['predictions']

Javascript/Java/C#

Sending requests in JavaScript, Java, and C# is covered elsewhere, and those examples should be straightforward to adapt.

rhaertel80 answered May 03 '23 18:05
