
Threading in tensorflow's input pipeline

Background

A typical input pipeline in TensorFlow looks as follows:

                  tf.train.string_input_producer(list_of_filenames)
                         (creates queue of filenames)
                                   |
                                  \|/
           fixed length reader reads records from the files
                                   |
                                  \|/
    Read records are decoded and processed (e.g. for images: cropping, flipping, etc.)
                                   |
                                  \|/
            tf.train.shuffle_batch(tensors,num_threads)
        (creates a shuffling queue and returns batches of tensors) 

Problem

Q1) tf.train.string_input_producer() has no num_threads argument. Does this mean that only a single thread is dedicated to reading the filenames from the filename queue?

Q2) What is the scope of the num_threads argument of tf.train.shuffle_batch()? That is, are the threads specified here also used to read, decode and process the files, or are they used only to create batches of tensors?

Q3) Is there a way to print which thread read the filenames or records from a particular file, i.e. a record of the work done by each thread?

Jasdeep Singh Chhabra asked Mar 25 '17 10:03



1 Answer

All of your data loading operations are performed within the TensorFlow graph, so what you'll want to do is launch one or more threads to iterate over the reader/enqueue operations. TensorFlow provides a QueueRunner class that does exactly that, and the Coordinator class lets you manage those threads easily.

https://www.tensorflow.org/programmers_guide/threading_and_queues

This example is adapted from the link above:

import tensorflow as tf  # TensorFlow 1.x API

# queue, enqueue_op and train_op are assumed to be defined earlier in the graph.

# Create a queue runner that will run 4 threads in parallel to enqueue
# examples.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

# Launch the graph.
sess = tf.Session()
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# Run the training loop, controlling termination with the coordinator.
for step in range(1000000):
    if coord.should_stop():
        break
    sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(enqueue_threads)

If you were loading/preprocessing samples outside of the graph (in your own code, not using TF operations), then you wouldn't use QueueRunner. Instead, you would use your own class to enqueue data by calling sess.run(enqueue_op, feed_dict={...}) in a loop.
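To make the "your own class" idea concrete, here is a minimal sketch in plain Python. It does not use TensorFlow at all: a standard-library queue.Queue stands in for the TF queue, and a threading.Event stands in for the Coordinator's stop request. The Feeder class and run_demo function are hypothetical names made up for illustration. In a real TF 1.x program, the put call would be replaced by sess.run(enqueue_op, feed_dict={...}).

```python
import queue
import threading


class Feeder:
    """Hypothetical helper: fills a bounded queue from an iterable on a background thread."""

    def __init__(self, samples, capacity=8):
        self._samples = samples
        self.queue = queue.Queue(maxsize=capacity)
        # Plays the role that coord.request_stop() plays in the TF example.
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        for sample in self._samples:
            # Retry the put with a timeout so a stop request is noticed
            # even while the queue is full.
            while not self._stop.is_set():
                try:
                    # In TensorFlow this would be
                    # sess.run(enqueue_op, feed_dict={...}).
                    self.queue.put(sample, timeout=0.1)
                    break
                except queue.Full:
                    continue
            if self._stop.is_set():
                return

    def start(self):
        self._thread.start()

    def stop(self):
        self._stop.set()
        self._thread.join()


def run_demo():
    feeder = Feeder(range(100), capacity=4)
    feeder.start()
    # The "training loop": pull ten samples, then shut the feeder down.
    consumed = [feeder.queue.get() for _ in range(10)]
    feeder.stop()
    return consumed
```

The bounded queue gives the same back-pressure a TF queue with a fixed capacity would: the feeder blocks when the consumer falls behind, and the stop event lets it exit cleanly instead of hanging on a full queue.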

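Q1: The number of threads is set by how many enqueue ops you pass to the QueueRunner ([enqueue_op] * 4 in the example above gives four threads). qr.create_threads(sess, coord=coord, start=True) then creates and starts those threads.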

Q2: TF sessions are thread-safe. Each call to sess.run(...) sees a consistent snapshot of the current variables as of when it began. Your QueueRunner enqueue ops can run on any number of threads; they'll enqueue in a thread-safe manner.

Q3: I haven't used tf.train.string_input_producer myself, but I think you'd have to request a tensor later in the graph that dequeues the data. Just add that tensor to your list of fetches: sess.run([train_op, dequeue_op]).
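If the reading happens in your own Python threads rather than inside the graph, keeping a record of which thread handled which file is straightforward. The sketch below is an assumption-laden illustration using only the standard library: process_files and the "reader-N" thread names are hypothetical, and this is not how TensorFlow's queue runner threads report their work.

```python
import queue
import threading


def process_files(filenames, num_threads=4):
    """Return (thread_name, filename) pairs recording which thread handled which file."""
    work = queue.Queue()
    for name in filenames:
        work.put(name)

    log = []
    log_lock = threading.Lock()

    def worker():
        while True:
            try:
                name = work.get_nowait()
            except queue.Empty:
                return  # No work left for this thread.
            # Real reading/decoding of the file would happen here.
            with log_lock:
                log.append((threading.current_thread().name, name))

    threads = [
        threading.Thread(target=worker, name=f"reader-{i}")
        for i in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log
```

Calling process_files(["a.bin", "b.bin", "c.bin"], num_threads=2) returns one entry per file. Which thread picks up which file depends on scheduling, so the assignment can differ between runs.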

David Parks answered Nov 15 '22 03:11
