
TensorFlow : Enqueuing and dequeuing a queue from multiple threads

The problem I am trying to solve is as follows: I have a list trainimgs of filenames. I have defined a

  • tf.RandomShuffleQueue with its capacity=len(trainimgs) and min_after_dequeue=0.
  • This tf.RandomShuffleQueue is expected to be filled from trainimgs a specified number of times (epochlimit).
  • A number of threads are expected to work in parallel. Each thread dequeues an element from the tf.RandomShuffleQueue, performs some operations on it, and enqueues it to another queue. I have got that part right.
  • However, once 1 epoch of trainimgs has been processed and the tf.RandomShuffleQueue is empty, provided that the current epoch e < epochlimit, the queue must be filled up again and the threads must resume work.

The good news is: I have got it working in a certain case (see the PS at the end!).

The bad news is: I think there is a better way of doing this.

The method I am using now is as follows (I have simplified the functions and removed the image-processing-based preprocessing and subsequent enqueuing, but the heart of the processing remains the same!):

import threading
import tensorflow as tf

with tf.Session() as sess:
    train_filename_queue = tf.RandomShuffleQueue(capacity=len(trainimgs),
                                                 min_after_dequeue=0,
                                                 dtypes=tf.string, seed=0)
    queue_size = train_filename_queue.size()
    trainimgtensor = tf.constant(trainimgs)
    close_queue = train_filename_queue.close()
    epoch = tf.Variable(initial_value=1, trainable=False, dtype=tf.int32)
    incrementepoch = tf.assign(epoch, epoch + 1, use_locking=True)
    supplyimages = train_filename_queue.enqueue_many(trainimgtensor)
    value = train_filename_queue.dequeue()

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess, coord)
    sess.run(supplyimages)  # fill the queue for the first epoch
    lock = threading.Lock()
    threads = [threading.Thread(target=work,
                                args=(coord, value, sess, epoch, incrementepoch,
                                      supplyimages, queue_size, lock, close_queue))
               for i in range(200)]
    for t in threads:
        t.start()
    coord.join(threads)

The work function is as follows:

def work(coord, val, sess, epoch, incrementepoch, supplyimg, q, lock, close_op):
    while not coord.should_stop():
        if sess.run(q) > 0:
            filename, currepoch = sess.run([val, epoch])
            filename = filename.decode(encoding='UTF-8')
            print(filename + ' ' + str(currepoch))
        elif sess.run(epoch) < 2:
            lock.acquire()
            try:
                if sess.run(q) == 0:
                    print("The previous epoch = %d" % sess.run(epoch))
                    sess.run([incrementepoch, supplyimg])
                    sz = sess.run(q)
                    print("The new epoch = %d" % sess.run(epoch))
                    print("The new queue size = %d" % sz)
            finally:
                lock.release()
        else:
            try:
                sess.run(close_op)
            except tf.errors.CancelledError:
                print('Queue already closed.')
            coord.request_stop()
    return None

So, although this works, I have a feeling that there is a better and cleaner way to achieve this. In a nutshell, my questions are:

  1. Is there a simpler and cleaner way of achieving this task in TensorFlow?
  2. Is there any problem with this code's logic? I am not very experienced with multithreading, so any obvious faults that have escaped my attention would be very helpful to know about.

P.S.: It seems that this code is not perfect after all. When I ran it with 1.2 million images and 200 threads, it worked. However, when I run it with 10 images and 20 threads, it gives the following error:

CancelledError (see above for traceback): RandomShuffleQueue '_0_random_shuffle_queue' is closed.
     [[Node: random_shuffle_queue_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](random_shuffle_queue, Const)]]

I thought I had that covered by except tf.errors.CancelledError. What is going on here?

asked Feb 28 '17 by Ujjwal
2 Answers

I finally found the answer. The problem was that multiple threads were clashing with each other at various points in the work() function. The following work() function works perfectly.

def work(coord, val, sess, epoch, maxepochs, incrementepoch, supplyimg, q, lock, close_op):
    print('I am thread number %s' % threading.current_thread().name)
    print('I can see a queue with size %d' % sess.run(q))
    while not coord.should_stop():
        lock.acquire()
        if sess.run(q) > 0:
            filename, currepoch = sess.run([val, epoch])
            filename = filename.decode(encoding='UTF-8')
            tid = threading.current_thread().name
            print(filename + ' ' + str(currepoch) + ' thread ' + str(tid))
        elif sess.run(epoch) < maxepochs:
            print('Thread %s has acquired the lock' % threading.current_thread().name)
            print("The previous epoch = %d" % sess.run(epoch))
            sess.run([incrementepoch, supplyimg])
            sz = sess.run(q)
            print("The new epoch = %d" % sess.run(epoch))
            print("The new queue size = %d" % sz)
        else:
            coord.request_stop()
        lock.release()

    return None
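The essential change is that the queue-size check and the action taken on its result now happen under the same lock, so no thread can act on a stale size reading. The same check-then-act-under-lock pattern can be sketched without TensorFlow, using only Python's standard queue and threading modules (the item names and thread count are made up for illustration):

```python
import queue
import threading

items = ['img0', 'img1', 'img2']   # stand-ins for filenames
max_epochs = 2
q = queue.Queue()
for it in items:                   # fill the queue for epoch 1
    q.put(it)

epoch = 1
lock = threading.Lock()
stop = threading.Event()
processed = []

def work():
    global epoch
    while not stop.is_set():
        with lock:                 # size check + action are atomic
            if not q.empty():
                processed.append((q.get(), epoch))
            elif epoch < max_epochs:
                epoch += 1         # refill for the next epoch
                for it in items:
                    q.put(it)
            else:
                stop.set()         # all epochs consumed: tell everyone to stop

threads = [threading.Thread(target=work) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(processed))              # 6: 3 items x 2 epochs
```

Because every queue operation happens while holding the lock, no two threads can both see an empty queue and refill it twice, which is exactly the race the original work() function suffered from.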
answered Sep 19 '22 by Ujjwal


I recommend having a single thread call enqueue_many epochs times to enqueue the correct total number of images, and then close the queue. This would let you simplify your work function and the other threads.
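A minimal sketch of that shape, again using Python's standard queue module in place of the TensorFlow queue (a None sentinel per worker stands in for closing the queue; names and counts are illustrative):

```python
import queue
import threading

items = ['img0', 'img1', 'img2']
epochs = 2
num_workers = 4
q = queue.Queue()
results = []
results_lock = threading.Lock()

def filler():
    # One thread enqueues every epoch's worth of items up front,
    # then "closes" the queue by sending one sentinel per worker.
    for _ in range(epochs):
        for it in items:
            q.put(it)
    for _ in range(num_workers):
        q.put(None)

def worker():
    while True:
        it = q.get()
        if it is None:             # queue "closed": exit cleanly
            break
        with results_lock:
            results.append(it)     # stand-in for the real processing

threads = [threading.Thread(target=filler)]
threads += [threading.Thread(target=worker) for _ in range(num_workers)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(results))                # 6: every item of every epoch, once
```

With a single filler thread, the workers never need to know about epochs, locks around refilling, or who closes the queue, which removes the races the question ran into.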

answered Sep 21 '22 by Alexandre Passos