Is there a version of the TensorFlow Object Detection API inference example that can run on batches of images simultaneously?

I have trained a Faster R-CNN model using the TensorFlow Object Detection API and am using this inference script with my frozen graph:

https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

I intend to use it for object tracking in videos, but inference with this script is very slow because it processes only one image at a time instead of a batch of images. Is there any way to run inference on a batch of images at once? The relevant inference function is below; I am wondering how to modify it to work with a stack of images:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session), as used by the tutorial
from object_detection.utils import ops as utils_ops


def run_inference_for_single_image(image, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes', 'detection_masks']:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)
            if 'detection_masks' in tensor_dict:
                # The following processing is only for single image
                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
                # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(detection_masks, detection_boxes, image.shape[0], image.shape[1])
                detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                # Follow the convention by adding back the batch dimension
                tensor_dict['detection_masks'] = tf.expand_dims(detection_masks_reframed, 0)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

            # Run inference; np.expand_dims adds a batch axis of size 1
            output_dict = sess.run(tensor_dict, feed_dict={image_tensor: np.expand_dims(image, 0)})

            # all outputs are float32 numpy arrays, so convert types as appropriate
            # ([0] selects the only image in the batch of size 1)
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
            if 'detection_masks' in output_dict:
                output_dict['detection_masks'] = output_dict['detection_masks'][0]
    return output_dict
asked Mar 08 '18 12:03 by Rohit Gupta


1 Answer

Instead of passing just one NumPy array of shape (1, image_height, image_width, 3), you can pass a NumPy array holding your whole image batch, of shape (batch_size, image_height, image_width, 3), to the sess.run call. All images in the batch need the same height and width so that they can be stacked into one array:

output_dict = sess.run(tensor_dict, feed_dict={image_tensor: image_batch})
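For the video use case in the question, image_batch could be built from consecutive frames. Below is a minimal sketch, assuming the frames arrive as HxWx3 uint8 NumPy arrays of identical size (as they do when decoded from a single video); the generator name frame_batches and its batch_size parameter are hypothetical and not part of the Object Detection API:

```python
import numpy as np


def frame_batches(frames, batch_size):
    """Group an iterable of HxWx3 uint8 frames into (N, H, W, 3) arrays.

    Hypothetical helper: assumes every frame has the same height and width,
    which np.stack requires. The last batch may hold fewer than batch_size frames.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    buffer = []
    for frame in frames:
        buffer.append(frame)
        if len(buffer) == batch_size:
            # np.stack adds a new leading axis: (batch, height, width, channels)
            yield np.stack(buffer)
            buffer = []
    if buffer:
        # Flush the leftover frames as a final, smaller batch
        yield np.stack(buffer)


# Usage with dummy frames standing in for decoded video frames
frames = [np.zeros((4, 6, 3), dtype=np.uint8) for _ in range(5)]
for image_batch in frame_batches(frames, batch_size=2):
    print(image_batch.shape)  # (2, 4, 6, 3) twice, then (1, 4, 6, 3)
```

Each image_batch can then be fed as feed_dict={image_tensor: image_batch}. A smaller final batch should work as long as your exported graph leaves the batch dimension of image_tensor unspecified, which is worth checking for your own frozen graph.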

The output_dict will be slightly different than before; I still haven't figured out exactly how. Maybe someone can help further?

Edit

It seems that each entry of the output_dict gets an extra leading index that corresponds to the image's position in your batch, so the [0] indexing in the single-image post-processing no longer applies. You'll find the boxes for a given image in: output_dict['detection_boxes'][image_counter]
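To get back one result per image, the batched output_dict can be split along that leading index. This is a minimal sketch, assuming every array has a leading batch axis (e.g. detection_boxes of shape (batch, max_detections, 4) and num_detections of shape (batch,)) and that entries beyond num_detections are padding; the helper name split_batched_output is hypothetical, and masks are left out here:

```python
import numpy as np


def split_batched_output(output_dict):
    """Turn one batched output_dict into a list of per-image dicts.

    Hypothetical helper: assumes a leading batch axis on every array and
    that detections past num_detections[i] are padding for image i.
    """
    batch_size = output_dict['detection_boxes'].shape[0]
    results = []
    for i in range(batch_size):
        # Number of valid detections for image i (stored as float32)
        n = int(output_dict['num_detections'][i])
        results.append({
            'num_detections': n,
            'detection_boxes': output_dict['detection_boxes'][i][:n],
            'detection_scores': output_dict['detection_scores'][i][:n],
            # Class ids come back as float32; cast them to integers
            'detection_classes': output_dict['detection_classes'][i][:n].astype(np.int64),
        })
    return results


# Usage with a fake batch of two images, three detection slots each
fake = {
    'num_detections': np.array([2.0, 1.0], dtype=np.float32),
    'detection_boxes': np.zeros((2, 3, 4), dtype=np.float32),
    'detection_scores': np.array([[0.9, 0.8, 0.0], [0.7, 0.0, 0.0]], dtype=np.float32),
    'detection_classes': np.array([[1.0, 3.0, 0.0], [2.0, 0.0, 0.0]], dtype=np.float32),
}
for per_image in split_batched_output(fake):
    print(per_image['num_detections'], per_image['detection_classes'])
```

The classes are cast to int64 rather than the tutorial's uint8 so that class ids above 255 cannot wrap around.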

Edit2

For some reason this won't work with Mask RCNN...
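A likely reason, judging from the tutorial function itself (its comment says the mask processing "is only for single image"): tf.squeeze(..., [0]) only succeeds when the batch axis has size 1, and the reframing step uses image.shape[0] and image.shape[1] of a single image. The sketch below reproduces the squeeze problem with NumPy's equivalent and shows a per-image loop instead. reframe_fn is a hypothetical placeholder for something that runs utils_ops.reframe_box_masks_to_image_masks on one image's masks and boxes; this is an untested workaround idea, not the API's own batched mask handling:

```python
import numpy as np


def squeeze_batch_axis(tensor):
    """NumPy stand-in for the tutorial's tf.squeeze(x, [0])."""
    # Raises ValueError when axis 0 has a size other than 1
    return np.squeeze(tensor, axis=0)


def per_image_masks(detection_masks, detection_boxes, num_detections, reframe_fn):
    """Slice masks and boxes image by image instead of squeezing axis 0.

    reframe_fn(masks, boxes) is a hypothetical callable that reframes the
    masks of ONE image to that image's size.
    """
    results = []
    for i in range(detection_masks.shape[0]):
        n = int(num_detections[i])
        # Equivalent of the tutorial's tf.slice(..., [0, ...], [n, -1, ...]) for image i
        results.append(reframe_fn(detection_masks[i][:n], detection_boxes[i][:n]))
    return results


masks = np.zeros((2, 5, 15, 15), dtype=np.float32)  # batch of 2 images
boxes = np.zeros((2, 5, 4), dtype=np.float32)
try:
    squeeze_batch_axis(masks)
except ValueError as err:
    print("squeeze failed:", err)
print(per_image_masks(masks, boxes, np.array([3.0, 1.0]), lambda m, b: m.shape))
```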

answered Sep 19 '22 13:09 by Thommy257