Batch prediction using a trained Object Detection APIs model and TF 2

I successfully trained a model using Object Detection APIs for TF 2 on TPUs and exported it in SavedModel format (a .pb file). I then load it back using tf.saved_model.load, and it works fine when predicting boxes for a single image converted to a tensor with shape (1, h, w, 3).

import tensorflow as tf
import numpy as np

# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')

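image_path = '/path/to/image.jpg'  # hypothetical placeholder path to a JPEG image
image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()  # shape (h, w, 3)
input_tensor = np.expand_dims(image_np, 0)  # add the batch dimension: (1, h, w, 3)
detections = detect_fn(input_tensor) # This works fine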

The problem is that I need to turn this into batch prediction to scale it to half a million images, but the input signature of this model seems to accept only data with shape (1, h, w, 3). This also means that I can't use batch processing with TensorFlow Serving. How can I solve this problem? Can I simply change the model signature to handle batches of data?

All work (loading model + predictions) was performed inside the official container released with the Object Detection APIs (from here)

Alberto asked Sep 02 '20 09:09

1 Answer

I ran into this issue recently. When you use exporter_main_v2.py to convert checkpoint files to a .pb file, it calls exporter_lib_v2.py. In that file (here), the input signature is hard-coded with shape [1, None, None, 3]. To accept batches, it has to be changed to [None, None, None, 3].
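
To illustrate why a signature with a fixed leading 1 rejects batches, here is a minimal pure-Python sketch of the matching rule: a None dimension accepts any size, while a fixed dimension must match exactly. The helper shape_matches is hypothetical and written only for illustration; it is not part of TensorFlow or the Object Detection API, and the image sizes are made-up examples.

```python
def shape_matches(spec, shape):
    """Return True if a concrete shape fits a signature spec (None = any size)."""
    if len(spec) != len(shape):
        return False
    return all(s is None or s == d for s, d in zip(spec, shape))


# Shape the stock exporter writes, according to the answer above.
fixed_spec = (1, None, None, 3)
# Shape after changing the leading 1 to None.
batch_spec = (None, None, None, 3)

single = (1, 480, 640, 3)  # one image: (batch, height, width, channels)
batch = (8, 480, 640, 3)   # eight images of the same size

print(shape_matches(fixed_spec, single))  # True
print(shape_matches(fixed_spec, batch))   # False: batch size 8 != 1
print(shape_matches(batch_spec, batch))   # True
```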

Change the 1 to None on lines 138, 162, 170, and 185 of that file. Then rebuild the TF2 Object Detection API repo (link) and use the newly built version to export the .pb file again.
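
Once the model accepts a variable batch dimension, the images in a batch still have to be stacked into one dense tensor, so they need the same height and width (either by resizing them or by grouping them by size). Below is a minimal sketch of the grouping approach, assuming the image sizes are already known; make_batches and the sample file names are hypothetical and not part of the Object Detection API.

```python
from collections import defaultdict


def make_batches(items, batch_size):
    """Group (path, (height, width)) pairs into batches of same-size images.

    Returns a list of ((height, width), [paths]) tuples, each holding at
    most batch_size paths.
    """
    by_size = defaultdict(list)
    for path, size in items:
        by_size[size].append(path)

    batches = []
    for size, paths in by_size.items():
        # Slice each same-size group into chunks of at most batch_size.
        for i in range(0, len(paths), batch_size):
            batches.append((size, paths[i:i + batch_size]))
    return batches


# Hypothetical inputs: file names paired with (height, width).
items = [
    ("a.jpg", (480, 640)),
    ("b.jpg", (480, 640)),
    ("c.jpg", (720, 1280)),
    ("d.jpg", (480, 640)),
]
batches = make_batches(items, batch_size=2)
for size, paths in batches:
    print(size, paths)
```

Each resulting group can then be decoded and stacked into a single tensor of shape (n, h, w, 3) before being passed to the re-exported model.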

Vo Minh Thanh answered Oct 21 '22 03:10