
TensorFlow shuffle_batch not working

Tags:

tensorflow

import tensorflow as tf
sess = tf.Session()

def add_to_batch(image):

    print('Adding to batch')
    image_batch = tf.train.shuffle_batch([image],batch_size=5,capacity=11,min_after_dequeue=1,num_threads=1)

    # Add to summary
    tf.image_summary('images',image_batch)

    return image_batch

def get_batch():

    # Create filename queue of images to read
    filenames = [('/media/jessica/Jessica/TensorFlow/Practice/unlabeled_data_%d.png' % i) for i in range(11)]
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    # Read and process image
    my_image = tf.image.decode_png(value)
    my_image_float = tf.cast(my_image,tf.float32)
    image_mean = tf.reduce_mean(my_image_float)
    my_noise = tf.random_normal([96,96,3],mean=image_mean)
    my_image_noisy = my_image_float + my_noise
    print('Reading images')

    return add_to_batch(my_image_noisy)

def main ():

    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)
    writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/Practice/summary_logs', graph_def=sess.graph_def)
    merged = tf.merge_all_summaries()
    images = get_batch()
    summary_str = sess.run(merged)
    writer.add_summary(summary_str)

Hi,

I am trying to build a simple neural network in TensorFlow. I am trying to load my input images in batches. Now I am testing the code with 11 images and batch_size = 5. Eventually I will be working with 100000 images.

This code is adapted from TensorFlow's cifar10.py example. When I run it, the program hangs (it neither terminates nor raises an error) at the call tf.train.shuffle_batch([image],batch_size=5,capacity=1,min_after_dequeue=1,num_threads=1)

I've tried different combinations of batch_size, capacity, and min_after_dequeue, but I still can't figure out what is going wrong.

Any help will be much appreciated! Thank you!

jlhw asked Apr 16 '26 02:04


1 Answer

It looks like the problem arises because the statement

tf.train.start_queue_runners(sess=sess)

...executes before any queue runners have been created. If you move this line after images = get_batch(), your program should work.

What is the problem here? The tf.train.shuffle_batch() function internally uses a tf.RandomShuffleQueue to produce a randomized batch. Currently, the only way to get elements into that queue is to run a step that invokes the q.enqueue() op. To make this easier, TensorFlow has a notion of "queue runners" that are implicitly collected as you build your graph, and then started with a call to tf.train.start_queue_runners(). However, calling tf.train.start_queue_runners() only starts the queue runners that have been defined at that point in time, so it must come after the code that creates the queue runners.
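The ordering issue described above can be imitated without TensorFlow. The following is only an analogy built on Python's standard library, not TensorFlow's actual implementation: the names make_input_queue, start_runners, try_dequeue and demo are hypothetical, and the registry list stands in for the implicit collection of queue runners. A timeout is used so the "started too early" case returns instead of hanging; in the real program the dequeue has no timeout, so it blocks indefinitely.

```python
import queue
import threading


def make_input_queue(registry, items, capacity=11):
    """Create a bounded queue and register (but do not start) its filler.

    `registry` is a plain list standing in for the collection of queue
    runners that is populated while the graph is being built.
    """
    q = queue.Queue(maxsize=capacity)

    def fill():
        for item in items:
            q.put(item)

    registry.append(fill)
    return q


def start_runners(registry):
    """Start a thread for every runner registered *so far*.

    Runners added to the registry after this call are not started.
    """
    threads = [threading.Thread(target=fn, daemon=True) for fn in registry]
    for t in threads:
        t.start()
    return threads


def try_dequeue(q, timeout=0.5):
    """Return one item, or None if nothing arrives before the timeout."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


def demo(start_first):
    """Return (number of runners started, first dequeued item or None)."""
    registry = []
    if start_first:
        # Same order as the question: start runners, then build the queue.
        started = start_runners(registry)
        q = make_input_queue(registry, range(11))
    else:
        # Order suggested in the answer: build the queue, then start runners.
        q = make_input_queue(registry, range(11))
        started = start_runners(registry)
    return len(started), try_dequeue(q)


if __name__ == "__main__":
    print(demo(start_first=True))
    print(demo(start_first=False))
```

With start_first=True no runner exists when start_runners is called, so nothing ever fills the queue and the dequeue times out. With start_first=False the filler thread is running before the dequeue, so an item is available.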

mrry answered Apr 18 '26 09:04


