 

Tensorflow batch_join's allow_smaller_final_batch doesn't work?

I am using the TensorFlow queue to process my data, and I need to get the final batch whose size is smaller than the batch size. But I only ever get the 5 full batches; the final, smaller batch is never returned. I don't understand what the problem is.

import numpy as np
import tensorflow as tf

data = np.arange(105)
data_placeholder = tf.placeholder(dtype=tf.int64, shape=[None])

queue = tf.FIFOQueue(capacity=200, dtypes=tf.int64, shapes=())
enqueue_op = queue.enqueue_many([data_placeholder])

data_list = []
data_ = queue.dequeue()

data_list.append([data_])
batch_data = tf.train.batch_join(data_list, batch_size=20, capacity=100,
                                 allow_smaller_final_batch=True)

sess = tf.Session()

coord = tf.train.Coordinator()
tf.train.start_queue_runners(sess, coord)

# feed the data into the FIFO queue in chunks of 20
step = 0
under = 0
upper = 0
enqueue_steps = len(data) // 20 + 1
while step < enqueue_steps:
    upper = upper + 20
    sess.run(enqueue_op, feed_dict={data_placeholder: data[under:upper]})
    print("enqueue step=%d/%d %d-%d" % (step, enqueue_steps, under, upper))
    step = step + 1
    under = upper

# dequeue the batches; the last one should have only 5 elements
i = 0
while i < enqueue_steps:
    _data = sess.run(batch_data)
    print("step=%d/%d shape=%s" % (i, enqueue_steps, _data.shape))
    i = i + 1
print("end")
JinbiaoYu asked Jan 18 '26 19:01


1 Answer

I haven't checked your entire code, but if I understand correctly you want to get all samples even when the last batch is smaller than the rest, right?

Well, take this toy example with 8 samples and a batch size of 3:

import tensorflow as tf
import numpy as np

num_samples = 8
batch_size = 3
capacity = num_samples % batch_size # set the capacity to the actual remaining samples
data = np.arange(1, num_samples+1)
data_input = tf.constant(data)

batch = tf.train.batch([data_input], enqueue_many=True, batch_size=batch_size,
                       capacity=capacity, allow_smaller_final_batch=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(2):
        print(i, sess.run([batch]))
    coord.request_stop()
    coord.join(threads)
    # the coordinator has stopped the queue runners, which closes the queue,
    # so the remaining samples are returned as a smaller final batch
    print(i+1, sess.run([batch]))

Outcome:

0 [array([1, 2, 3])]

1 [array([4, 5, 6])]

2 [array([7, 8])]

The important parameters here are enqueue_many, so that each number is treated as a separate sample, and capacity, which is set to the actual number of remaining samples (here 2). If capacity is set to 1 you will get just 1 sample, and if it's 3 you will miss the allow_smaller_final_batch flag's effect, since it will return 3 samples (the last one wrapping around to the beginning).
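If you want to keep the FIFOQueue + batch_join setup from your question, the same idea applies: the smaller final batch is only emitted once the upstream queue has been closed; until then the run call just waits for more elements. Here is a minimal, untested sketch of that pattern (assuming TF 1.x), where the queue is closed after all the data has been fed and batches are drained until an OutOfRangeError signals there is nothing left:

import numpy as np
import tensorflow as tf

data = np.arange(105)
data_placeholder = tf.placeholder(dtype=tf.int64, shape=[None])

queue = tf.FIFOQueue(capacity=200, dtypes=tf.int64, shapes=())
enqueue_op = queue.enqueue_many([data_placeholder])
close_op = queue.close()  # signals that no more elements will be enqueued

batch_data = tf.train.batch_join([[queue.dequeue()]], batch_size=20,
                                 capacity=100, allow_smaller_final_batch=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # feed everything, then close the queue so downstream stages know it is done
    for start in range(0, len(data), 20):
        sess.run(enqueue_op, feed_dict={data_placeholder: data[start:start + 20]})
    sess.run(close_op)

    # drain batches; the last one should have only 5 elements
    try:
        while True:
            print(sess.run(batch_data).shape)
    except tf.errors.OutOfRangeError:
        print("end")

    coord.request_stop()
    coord.join(threads)

With 105 elements and a batch size of 20 this should print five (20,) shapes followed by a final (5,) one; without the queue.close() call the last sess.run would simply block waiting for 15 more elements that never arrive.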

Hope this clarifies how the allow_smaller_final_batch parameter should be used.

Eypros answered Jan 21 '26 09:01