The problem I am trying to solve is as follows:

I have a list trainimgs of filenames. I have defined a tf.RandomShuffleQueue with capacity=len(trainimgs) and min_after_dequeue=0. This tf.RandomShuffleQueue is expected to be filled from trainimgs a specified number of times (epochlimit). A number of threads are expected to work in parallel: each thread dequeues a filename from the tf.RandomShuffleQueue, does some operations on it and enqueues it to another queue. I have got that part right.

However, once 1 epoch of trainimgs has been processed and the tf.RandomShuffleQueue is empty, provided that the current epoch e < epochlimit, the queue must again be filled up and the threads must work again.

The good news is: I have got it working in a certain case (see the PS at the end!)
The bad news is: I think there is a better way of doing this.

The method I am using now is as follows (I have simplified the functions and removed the image-processing-based preprocessing and subsequent enqueuing, but the heart of the processing remains the same):
with tf.Session() as sess:
    train_filename_queue = tf.RandomShuffleQueue(capacity=len(trainimgs), min_after_dequeue=0, dtypes=tf.string, seed=0)
    queue_size = train_filename_queue.size()
    trainimgtensor = tf.constant(trainimgs)
    close_queue = train_filename_queue.close()
    epoch = tf.Variable(initial_value=1, trainable=False, dtype=tf.int32)
    incrementepoch = tf.assign(epoch, epoch + 1, use_locking=True)
    supplyimages = train_filename_queue.enqueue_many(trainimgtensor)
    value = train_filename_queue.dequeue()
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess, coord)
    sess.run(supplyimages)
    lock = threading.Lock()
    threads = [threading.Thread(target=work, args=(coord, value, sess, epoch, incrementepoch, supplyimages, queue_size, lock, close_queue)) for i in range(200)]
    for t in threads:
        t.start()
    coord.join(threads)
The work function is as follows:

def work(coord, val, sess, epoch, incrementepoch, supplyimg, q, lock, close_op):
    while not coord.should_stop():
        if sess.run(q) > 0:
            filename, currepoch = sess.run([val, epoch])
            filename = filename.decode(encoding='UTF-8')
            print(filename + ' ' + str(currepoch))
        elif sess.run(epoch) < 2:
            lock.acquire()
            try:
                if sess.run(q) == 0:
                    print("The previous epoch = %d" % (sess.run(epoch)))
                    sess.run([incrementepoch, supplyimg])
                    sz = sess.run(q)
                    print("The new epoch = %d" % (sess.run(epoch)))
                    print("The new queue size = %d" % (sz))
            finally:
                lock.release()
        else:
            try:
                sess.run(close_op)
            except tf.errors.CancelledError:
                print('Queue already closed.')
            coord.request_stop()
    return None
So, although this works, I have a feeling that there is a better and cleaner way to achieve this. So, in a nutshell, my question is: is there a better and cleaner way to achieve this?
P.S.: It seems that this code is not perfect after all. When I ran it with 1.2 million images and 200 threads, it worked. However, when I run it with 10 images and 20 threads, it gives the following error:
CancelledError (see above for traceback): RandomShuffleQueue '_0_random_shuffle_queue' is closed.
[[Node: random_shuffle_queue_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](random_shuffle_queue, Const)]]
I thought I got that covered by except tf.errors.CancelledError. What the hell is going on here?
I finally found out the answer. The problem was that multiple threads were clashing at various points in the work() function.
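To see why the threads clash, here is a minimal, framework-agnostic sketch of the check-then-act race between the size check (sess.run(q) > 0) and the subsequent dequeue, using Python's queue.Queue in place of the TensorFlow queue. The interleaving of the two hypothetical threads A and B is written out explicitly in a single thread for illustration:

```python
import queue

q = queue.Queue()
q.put("img.jpg")                # the queue holds exactly one item

# Thread A: checks the size and sees one item.
a_sees = q.qsize() > 0          # True
# Thread B: runs between A's check and A's dequeue,
# passes the same check, and takes the last item.
b_sees = q.qsize() > 0          # True
item_b = q.get_nowait()         # B wins the race
# Thread A: now acts on its stale check and finds the queue empty.
try:
    q.get_nowait()
except queue.Empty:
    print("Thread A's size check was stale: check-then-act is not atomic")
```

In the original work() function, the same stale check lets one thread close the queue while another is still about to enqueue, which produces the CancelledError; the fix below makes the whole check-and-act sequence hold the lock.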
The following work() function works perfectly.
def work(coord, val, sess, epoch, maxepochs, incrementepoch, supplyimg, q, lock, close_op):
    print('I am thread number %s' % (threading.current_thread().name))
    print('I can see a queue with size %d' % (sess.run(q)))
    while not coord.should_stop():
        lock.acquire()
        if sess.run(q) > 0:
            filename, currepoch = sess.run([val, epoch])
            filename = filename.decode(encoding='UTF-8')
            tid = threading.current_thread().name
            print(filename + ' ' + str(currepoch) + ' thread ' + str(tid))
        elif sess.run(epoch) < maxepochs:
            print('Thread %s has acquired the lock' % (threading.current_thread().name))
            print("The previous epoch = %d" % (sess.run(epoch)))
            sess.run([incrementepoch, supplyimg])
            sz = sess.run(q)
            print("The new epoch = %d" % (sess.run(epoch)))
            print("The new queue size = %d" % (sz))
        else:
            coord.request_stop()
        lock.release()
    return None
I recommend having a single thread call enqueue_many epochs times so that it enqueues the correct total number of images; it can then close the queue. This would let you simplify your work function and the other threads.
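Following that recommendation, here is a minimal, framework-agnostic sketch of the single-filler-thread design, using Python's queue.Queue with one None sentinel per worker as a stand-in for TensorFlow's queue close (the names fill_queue and worker are illustrative, not from the original code):

```python
import queue
import threading

def fill_queue(q, filenames, epochs, num_workers):
    # Single producer: enqueue every filename once per epoch.
    for _ in range(epochs):
        for name in filenames:
            q.put(name)
    # "Close" the queue: one sentinel per worker, so every worker
    # eventually dequeues exactly one None and exits.
    for _ in range(num_workers):
        q.put(None)

def worker(q, results):
    while True:
        name = q.get()
        if name is None:        # sentinel: the queue is exhausted
            break
        results.append(name)    # stand-in for the real preprocessing

filenames = ["a.jpg", "b.jpg", "c.jpg"]
epochs, num_workers = 2, 4
q = queue.Queue()
results = []
threads = [threading.Thread(target=worker, args=(q, results)) for _ in range(num_workers)]
for t in threads:
    t.start()
fill_queue(q, filenames, epochs, num_workers)
for t in threads:
    t.join()
print(len(results))             # len(filenames) * epochs items processed
```

Because only the filler thread ever enqueues or "closes" the queue, the workers never need a lock or an epoch counter: they just dequeue until they see the sentinel, which is exactly the simplification suggested above.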