
Feeding image data in tensorflow for transfer learning

I am trying to use TensorFlow for transfer learning. I downloaded the pre-trained Inception v3 model from the tutorial. In the code, the prediction is made like this:

prediction = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})

Is there a way to feed a PNG image instead? I tried changing DecodeJpeg to DecodePng, but it did not work. Besides, what should I change if I want to feed an already-decoded image, such as a NumPy array or a batch of arrays?

Thanks!!

Geoffrey Wu asked Dec 27 '15 19:12




2 Answers

The shipped InceptionV3 graph used in classify_image.py only supports JPEG images out-of-the-box. There are two ways you could use this graph with PNG images:

  1. Convert the PNG image to a height x width x 3 (channels) NumPy array, for example using PIL, then feed the 'DecodeJpeg:0' tensor:

    import numpy as np
    from PIL import Image
    # ...
    
    image = Image.open("example.png")
    image_array = np.array(image)[:, :, 0:3]  # Select RGB channels only.
    
    prediction = sess.run(softmax_tensor, {'DecodeJpeg:0': image_array})
    

    Perhaps confusingly, 'DecodeJpeg:0' is the output of the DecodeJpeg op, so feeding this tensor bypasses JPEG decoding and lets you supply already-decoded pixel data.

  2. Add a tf.image.decode_png() op to the imported graph. Simply switching the name of the fed tensor from 'DecodeJpeg/contents:0' to 'DecodePng/contents:0' does not work because there is no 'DecodePng' op in the shipped graph. You can add such a node to the graph by using the input_map argument to tf.import_graph_def():

    png_data = tf.placeholder(tf.string, shape=[])
    decoded_png = tf.image.decode_png(png_data, channels=3)
    # ...
    
    graph_def = ...
    # import_graph_def returns a list, so unpack its single element.
    softmax_tensor, = tf.import_graph_def(
        graph_def,
        input_map={'DecodeJpeg:0': decoded_png},
        return_elements=['softmax:0'])
    
    sess.run(softmax_tensor, {png_data: ...})
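
One caveat for option 1: the slice `[:, :, 0:3]` assumes the PNG decodes to a 3-D array with at least three channels. A grayscale PNG typically decodes to a 2-D array, so the slice would fail. The sketch below is one possible way to normalize such arrays before feeding 'DecodeJpeg:0'. The helper name `to_rgb_array` is hypothetical, and it assumes 8-bit pixel data.

```python
import numpy as np


def to_rgb_array(pixels):
    """Return an H x W x 3 uint8 array from grayscale, RGB, or RGBA pixels.

    Hypothetical helper (not part of TensorFlow or PIL). Assumes 8-bit data.
    """
    arr = np.asarray(pixels)
    if arr.dtype != np.uint8:
        raise ValueError("expected uint8 pixels, got %s" % arr.dtype)
    if arr.ndim == 2:
        # Grayscale (H x W): replicate the single channel three times.
        return np.stack([arr] * 3, axis=-1)
    if arr.ndim == 3 and arr.shape[2] in (1, 2):
        # Gray or gray+alpha: keep the gray channel and replicate it.
        return np.repeat(arr[:, :, :1], 3, axis=2)
    if arr.ndim == 3 and arr.shape[2] >= 3:
        # RGB or RGBA: keep the first three channels and drop alpha.
        return arr[:, :, :3]
    raise ValueError("unsupported image shape: %r" % (arr.shape,))
```

With PIL this could be used as `image_array = to_rgb_array(np.array(Image.open("example.png")))`. Palette-mode PNGs store colour indices rather than colours, so for those it is probably simpler to call `Image.open(...).convert("RGB")` first.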
    
mrry answered Sep 19 '22 15:09


The following code should handle both cases (JPEG and PNG).

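import numpy as np
import tensorflow as tf
from PIL import Image

image_file = 'test.jpeg'
with tf.Session() as sess:
    # Assumes the retrained graph has already been imported into the default graph.
    # softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    if image_file.lower().endswith(('.jpg', '.jpeg')):
        # JPEG: feed the encoded bytes and let the graph decode them.
        image_data = tf.gfile.FastGFile(image_file, 'rb').read()
        prediction = sess.run('final_result:0', {'DecodeJpeg/contents:0': image_data})
    elif image_file.lower().endswith('.png'):
        # PNG: decode with PIL and feed the pixel array past the JPEG decoder.
        image = Image.open(image_file)
        image_array = np.array(image)[:, :, 0:3]
        prediction = sess.run('final_result:0', {'DecodeJpeg:0': image_array})
    else:
        raise ValueError('Unsupported image type: ' + image_file)

    prediction = prediction[0]
    print(prediction)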

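Or a shorter version that feeds the encoded file contents directly. Note that the placeholder has to be wired into the graph when it is imported, otherwise feeding it has no effect:

image_file = 'test.png'  # or 'test.jpeg'
image_data = tf.gfile.FastGFile(image_file, 'rb').read()
ph = tf.placeholder(tf.string, shape=[])
decoded = tf.image.decode_image(ph, channels=3)  # detects PNG or JPEG

graph_def = ...  # the retrained GraphDef, loaded beforehand
# output_layer_name is the output tensor's name, e.g. 'final_result:0'
output, = tf.import_graph_def(
    graph_def, input_map={'DecodeJpeg:0': decoded},
    return_elements=[output_layer_name])

with tf.Session() as sess:
    predictions = sess.run(output, {ph: image_data})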
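
The first snippet in this answer picks the feed tensor from the file extension, which can be wrong for misnamed files. As a possible alternative, the sketch below looks at the leading bytes of the file instead: PNG files start with a fixed 8-byte signature and JPEG files start with the bytes FF D8. The function names are hypothetical; the tensor names are the ones used above.

```python
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # fixed 8-byte PNG file signature
JPEG_SOI = b"\xff\xd8"                # JPEG start-of-image marker


def sniff_image_format(data):
    """Return 'png', 'jpeg', or None based on the leading bytes of `data`."""
    if data.startswith(PNG_SIGNATURE):
        return "png"
    if data.startswith(JPEG_SOI):
        return "jpeg"
    return None


def feed_tensor_name(data):
    """Pick which tensor of the Inception graph to feed (hypothetical helper)."""
    fmt = sniff_image_format(data)
    if fmt == "jpeg":
        # The graph decodes JPEG bytes itself.
        return "DecodeJpeg/contents:0"
    if fmt == "png":
        # The caller must decode the PNG to an H x W x 3 array first.
        return "DecodeJpeg:0"
    raise ValueError("unrecognized image format")
```

Reading the file once with `open(image_file, 'rb').read()` and passing the bytes to `feed_tensor_name` would then replace the two `endswith` checks.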
shahar_m answered Sep 19 '22 15:09