
Single Image Inference in TensorFlow [Python]

I have already converted a pre-trained .ckpt file to a .pb file, freezing the model and saving the weights as well. What I am trying to do now is run a simple inference using that .pb file, then extract and save the output image. The model is a Fully Convolutional Network for Semantic Segmentation, downloaded from here: https://github.com/MarvinTeichmann/KittiSeg . So far I have managed to load the image, set the default tf graph and import the graph defined by the model into it, read the input and output tensors, and run the session (the error occurs here).

import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image

# Read the image & get statistics
img=Image.open('/path-to-image/demoImage.png')
img.show()
width, height = img.size
print(width)
print(height)

#Plot the image
#image.show()

with tf.Graph().as_default() as graph:

        with tf.Session() as sess:

                # Load the graph in graph_def
                print("load graph")

                # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
                with gfile.FastGFile("/path-to-FCN-model/FCN8.pb",'rb') as f:

                                #Set default graph as current graph
                                graph_def = tf.GraphDef()
                                graph_def.ParseFromString(f.read())
                                #sess.graph.as_default() #new line

                                # Import a graph_def into the current default Graph
                                tf.import_graph_def(graph_def, name='')

                                # Print the name of operations in the session
                                #for op in sess.graph.get_operations():

                                    #print "Operation Name :",op.name            # Operation name
                                    #print "Tensor Stats :",str(op.values())     # Tensor name

                                # INFERENCE Here
                                l_input = graph.get_tensor_by_name('Placeholder:0')
                                l_output = graph.get_tensor_by_name('save/Assign_38:0')

                                print "l_input", l_input
                                print "l_output", l_output
                                print
                                print

                                # Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.                              
                                result = sess.run(l_output, feed_dict={l_input : img})
                                print(result)

                                print("Inference done")

                                # Info
                                # First Tensor name : Placeholder:0
                                # Last tensor name  : save/Assign_38:0

Can the error come from the format of the image (e.g. should I convert the .png to another format)? Or is it another, more fundamental error?

Konstantinos Monachopoulos asked Aug 15 '17 17:08




1 Answer

I managed to fix the error. Below is the working script to run inference on a single image with a Fully Convolutional Network (for anyone interested in an alternative segmentation algorithm to SegNet). This model uses bilinear interpolation for upscaling rather than an un-pooling layer. Because the model is available for download in .ckpt format, you must first freeze the model and save it as a .pb file. After that, you must pass the network through the TF optimizer to set the dropout probabilities to 1. Then set the correct input and output tensor names in this script, and the inference works correctly, extracting the segmented image.
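A side note on the tensor names before the script itself. Names passed to get_tensor_by_name have the form `<op_name>:<output_index>`, and ops under the `save/` scope (such as `save/Assign_38` from the question) are normally created by the checkpoint Saver rather than by the network. The sketch below is plain Python and does not touch TensorFlow; the helper names `split_tensor_name` and `candidate_tensor_names` are hypothetical, and the op list is just the names that appear in this thread.

```python
def split_tensor_name(tensor_name):
    """Split '<op_name>:<output_index>' into (op_name, output_index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError("expected '<op_name>:<output_index>', got %r" % tensor_name)
    return op_name, int(index)


def candidate_tensor_names(op_names, skip_scopes=("save",)):
    """Return '<op_name>:0' for every op whose top-level scope is not skipped.

    Skipping 'save' is an assumption: it is the default scope used by
    tf.train.Saver, so ops there are usually bookkeeping, not model outputs.
    """
    names = []
    for op_name in op_names:
        top_scope = op_name.split("/", 1)[0]
        if top_scope in skip_scopes:
            continue
        names.append(op_name + ":0")
    return names


# Op names taken from the question and the answer in this thread.
op_names = [
    "Placeholder",
    "Inputs/fifo_queue_Dequeue",
    "upscore32/conv2d_transpose",
    "save/Assign_38",
]
print(candidate_tensor_names(op_names))
print(split_tensor_name("upscore32/conv2d_transpose:0"))
```

In a real session the op names would come from graph.get_operations(), as the script does when it prints them.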

import tensorflow as tf # Default graph is initialized when the library is imported
import os
from tensorflow.python.platform import gfile
from PIL import Image
import numpy as np
import scipy
from scipy import misc
import matplotlib.pyplot as plt
import cv2

with tf.Graph().as_default() as graph:  # Set default graph as graph

    with tf.Session() as sess:
        # Load the graph in graph_def
        print("load graph")

        # Load the protobuf file from disk and parse it to retrieve the unserialized graph_def
        with gfile.FastGFile("/path-to-protobuf/FCN8_Freezed.pb", 'rb') as f:

            print("Load Image...")
            # Read the image and get its statistics
            image = scipy.misc.imread('/Path-To-Image/uu_000010.png')
            image = image.astype(float)
            Input_image_shape = image.shape
            height, width, channels = Input_image_shape

            print("Plot image...")
            #scipy.misc.imshow(image)

            # Set FCN graph to the default graph
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()

            # Import a graph_def into the current default Graph
            # (in this case, the weights are (typically) embedded in the graph)
            tf.import_graph_def(
                graph_def,
                input_map=None,
                return_elements=None,
                name="",
                op_dict=None,
                producer_op_list=None
            )

            # Print the name of operations in the session
            for op in graph.get_operations():
                print "Operation Name :", op.name           # Operation name
                print "Tensor Stats :", str(op.values())    # Tensor name

            # INFERENCE Here
            l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0')  # Input Tensor
            l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0')  # Output Tensor

            print "Shape of input : ", tf.shape(l_input)
            # initialize_all_variables
            # (this only builds the init op; it is never run, and a frozen graph stores its weights as constants)
            tf.global_variables_initializer()

            # Run the KittiSeg model on a single image
            Session_out = sess.run(l_output, feed_dict={l_input: image})
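The script stops at the sess.run call, so the step that turns the fetched array into a segmented image is not shown. The following is only a sketch of one common way to do it, not the answer author's code. It assumes the fetched array holds per-class scores with shape (1, height, width, num_classes) or (height, width, num_classes), which is not confirmed by the thread. The helper name `logits_to_mask` is hypothetical, and a small synthetic array stands in for Session_out.

```python
import numpy as np


def logits_to_mask(logits):
    """Convert per-class scores to a (height, width) uint8 mask.

    Assumes `logits` has shape (1, H, W, C) or (H, W, C); for a 4-D input
    only the first batch element is used.
    """
    scores = np.asarray(logits)
    if scores.ndim == 4:
        scores = scores[0]  # drop the assumed batch dimension
    if scores.ndim != 3:
        raise ValueError("expected 3 or 4 dimensions, got %d" % scores.ndim)

    # Pick the highest-scoring class for every pixel.
    class_ids = np.argmax(scores, axis=-1)

    # Spread class ids over 0..255 so the mask is visible as a grayscale image.
    num_classes = scores.shape[-1]
    scale = 255 // max(num_classes - 1, 1)
    return (class_ids * scale).astype(np.uint8)


# Synthetic stand-in for Session_out: 2x3 pixels, 2 classes.
fake_output = np.zeros((1, 2, 3, 2), dtype=np.float32)
fake_output[0, 0, 1, 1] = 5.0  # one pixel scores higher for class 1

mask = logits_to_mask(fake_output)
print(mask.shape)
print(mask)
```

With the real output, the resulting mask could then be written to disk, for example with PIL's Image.fromarray(mask).save(...), using whatever output path suits the project.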
Konstantinos Monachopoulos answered Sep 24 '22 20:09