How does the distorted_inputs() function in the TensorFlow CIFAR-10 example tutorial get 128 images per batch?

I was going through the CIFAR-10 example in the TensorFlow getting-started guide for CNNs.

In the train() function in cifar10_train.py, we get the images as:

 images, labels = cifar10.distorted_inputs()

In the distorted_inputs() function, we put the filenames into a queue and then read a single record:

 # Create a queue that produces the filenames to read.
 filename_queue = tf.train.string_input_producer(filenames)

 # Read examples from files in the filename queue.
 read_input = cifar10_input.read_cifar10(filename_queue)
 reshaped_image = tf.cast(read_input.uint8image, tf.float32)

When I add debugging code, the read_input variable contains only one record: a single image along with its height, width, and label.
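To make the "one record per read" observation concrete, here is a pure-Python sketch (no TensorFlow) of decoding a single fixed-length CIFAR-10 record. It assumes the standard CIFAR-10 binary layout: 1 label byte followed by 3072 image bytes stored channel by channel. read_one_record and fake_file are hypothetical names for illustration, not part of cifar10_input.py.

```python
import random

# Assumed CIFAR-10 binary layout: 1 label byte, then 32*32*3 = 3072 image bytes
# stored channel by channel (all red, then all green, then all blue), row-major.
LABEL_BYTES = 1
HEIGHT, WIDTH, DEPTH = 32, 32, 3
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES  # 3073 bytes per record


def read_one_record(data, offset):
    """Decode ONE record starting at `offset` and return (label, image).

    `image` is a nested list indexed [row][col][channel], i.e. the
    [height, width, depth] layout obtained by transposing the stored
    [depth, height, width] bytes.
    """
    record = data[offset:offset + RECORD_BYTES]
    label = record[0]
    raw = record[LABEL_BYTES:]
    image = [[[raw[d * HEIGHT * WIDTH + h * WIDTH + w] for d in range(DEPTH)]
              for w in range(WIDTH)]
             for h in range(HEIGHT)]
    return label, image


# Hypothetical stand-in for a data_batch_*.bin file holding two random records.
rng = random.Random(0)
fake_file = bytes(rng.randrange(256) for _ in range(2 * RECORD_BYTES))

# Each call yields exactly one (label, image) pair, never a batch.
label, image = read_one_record(fake_file, 0)
print(label, len(image), len(image[0]), len(image[0][0]))  # the label, then 32 32 3
```

Getting the next image means reading again at the next offset; nothing in this step ever produces more than one example.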

The example then applies some distortions to the image/record and passes it to the _generate_image_and_label_batch() function.

This function then returns a 4D Tensor of shape [batch_size, 32, 32, 3] where batch_size = 128.

That function uses tf.train.shuffle_batch() to return the batch.

My question is: where do the extra records in tf.train.shuffle_batch() come from? We are not passing it any filename or reader object.

Can someone shed some light on how we go from 1 record to 128 records? I looked into the documentation but didn't understand.

t0mkaka asked Jan 04 '16 10:01


1 Answer

The tf.train.shuffle_batch() function can be used to produce (one or more) tensors containing a batch of inputs. Internally, tf.train.shuffle_batch() creates a tf.RandomShuffleQueue, on which it calls q.enqueue() with the image and label tensors to enqueue a single element (image-label pair). It then returns the result of q.dequeue_many(batch_size), which concatenates batch_size randomly selected elements (image-label pairs) into a batch of images and a batch of labels.
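The enqueue-one, dequeue-many pattern can be modeled with a toy queue in plain Python. This is a sketch of the idea, not TensorFlow's implementation: ToyRandomShuffleQueue is a hypothetical class, and its blocking rule (wait until min_after_dequeue + n elements are buffered) is an assumption meant to mirror the min_after_dequeue parameter of tf.train.shuffle_batch().

```python
import random
import threading


class ToyRandomShuffleQueue:
    """Toy model of a random-shuffle queue (illustration only, not tf.RandomShuffleQueue).

    enqueue() adds ONE element; dequeue_many(n) waits until at least
    min_after_dequeue + n elements are buffered, then removes n elements
    chosen uniformly at random and returns them together as one batch.
    """

    def __init__(self, capacity, min_after_dequeue, seed=None):
        self.capacity = capacity
        self.min_after_dequeue = min_after_dequeue
        self._items = []
        self._rng = random.Random(seed)
        self._cond = threading.Condition()

    def enqueue(self, element):
        with self._cond:
            while len(self._items) >= self.capacity:
                self._cond.wait()  # full: wait for a consumer
            self._items.append(element)
            self._cond.notify_all()

    def dequeue_many(self, n):
        with self._cond:
            while len(self._items) < self.min_after_dequeue + n:
                self._cond.wait()  # not enough buffered: wait for producers
            batch = []
            for _ in range(n):
                i = self._rng.randrange(len(self._items))
                # Swap the chosen element to the end and pop it (O(1) removal).
                self._items[i], self._items[-1] = self._items[-1], self._items[i]
                batch.append(self._items.pop())
            self._cond.notify_all()
            return batch


q = ToyRandomShuffleQueue(capacity=1000, min_after_dequeue=100, seed=0)
for i in range(300):
    # Each enqueue() adds a single (image, label) pair, one example at a time.
    q.enqueue(("image_%d" % i, i % 10))

batch = q.dequeue_many(128)  # 128 randomly chosen single examples
images = [img for img, _ in batch]
labels = [lbl for _, lbl in batch]
print(len(images), len(labels))  # 128 128
```

So the 128 records are not read in one go: they are 128 separate single-example enqueues that accumulated in the queue before one dequeue_many(128) pulled them out together.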

Note that, although it looks from the code like read_input and filename_queue have a functional relationship, there is an additional wrinkle. Simply evaluating the result of tf.train.shuffle_batch() would block forever, because no elements have been added to the internal queue. To make this easier, when you call tf.train.shuffle_batch(), TensorFlow adds a QueueRunner to an internal collection in the graph. A later call to tf.train.start_queue_runners() (e.g. in cifar10_train.py) starts a thread that adds elements to the queue, which allows training to proceed. The Threading and Queues HOWTO has more information on how this works.
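The blocking behavior and the role of the queue runner can be sketched with Python's standard library. Assumptions: queue.Queue stands in for the internal queue (FIFO here, whereas shuffle_batch uses a random-shuffle queue), the queue_runner thread plays the part of the QueueRunner, and starting that thread is the analogue of tf.train.start_queue_runners(). All names are hypothetical, and the timeouts exist only so the demo terminates instead of hanging forever.

```python
import queue
import threading

BATCH_SIZE = 128

# Stand-in for the internal example queue that shuffle_batch() creates.
example_queue = queue.Queue(maxsize=1000)


def dequeue_many(q, n, timeout):
    """Collect n single elements into one batch; raises queue.Empty on timeout."""
    return [q.get(timeout=timeout) for _ in range(n)]


def queue_runner(q, stop_event):
    """Hypothetical QueueRunner analogue: enqueue one example at a time until stopped."""
    i = 0
    while not stop_event.is_set():
        try:
            q.put(("image_%d" % i, i % 10), timeout=0.1)
            i += 1
        except queue.Full:
            continue  # queue is full; re-check the stop flag and retry


# 1) No runner started: the dequeue cannot complete (a real dequeue would block forever).
try:
    dequeue_many(example_queue, BATCH_SIZE, timeout=0.5)
    blocked = False
except queue.Empty:
    blocked = True
print("blocked without runner:", blocked)  # True

# 2) Analogue of tf.train.start_queue_runners(): start the producer thread.
stop = threading.Event()
runner = threading.Thread(target=queue_runner, args=(example_queue, stop), daemon=True)
runner.start()

batch = dequeue_many(example_queue, BATCH_SIZE, timeout=5.0)
print("batch size:", len(batch))  # 128

# Shut the producer down (cf. coordinator.request_stop() / join() in TensorFlow).
stop.set()
runner.join()
```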

mrry answered Oct 26 '22 18:10