
Tensorflow understanding tf.train.shuffle_batch

I have a single file of training data, about 100K rows, and I'm running a straightforward tf.train.GradientDescentOptimizer on each training step. The setup is essentially taken directly from Tensorflow's MNIST example. Code reproduced below:

x = tf.placeholder(tf.float32, [None, 21])
W = tf.Variable(tf.zeros([21, 2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 2])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

Given that I'm reading training data from a file, I'm using tf.train.string_input_producer and tf.decode_csv to read rows from the csv, and then tf.train.shuffle_batch to create batches that I then train on.

I'm confused as to what my parameters should be for tf.train.shuffle_batch. I read Tensorflow's documentation, and yet I'm still not sure what the "optimal" batch_size, capacity, and min_after_dequeue values are. Can anyone help shed some light on how I go about choosing proper values for these parameters, or link me to a resource where I can learn more? Thanks--

Here's the API link: https://www.tensorflow.org/versions/r0.9/api_docs/python/io_ops.html#shuffle_batch
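To build some intuition for what `min_after_dequeue` and `capacity` control, here is a tiny pure-Python model of the shuffle buffer — not the TensorFlow implementation, just a toy with hypothetical names: rows fill a buffer up to `capacity`, and each dequeue samples a random element while the buffer is kept at or above `min_after_dequeue` for as long as input remains. A larger `min_after_dequeue` means a bigger pool to sample from (better mixing) at the cost of memory and startup time.

```python
import random

def shuffle_batches(rows, batch_size, capacity, min_after_dequeue, seed=0):
    """Toy model of a shuffle_batch queue (illustration only)."""
    rng = random.Random(seed)
    it = iter(rows)
    buf = []
    # Pre-fill the buffer up to capacity (TF similarly blocks until
    # the queue holds at least min_after_dequeue elements).
    for row in it:
        buf.append(row)
        if len(buf) >= capacity:
            break
    batches = []
    while buf:
        batch = []
        while buf and len(batch) < batch_size:
            # Top the buffer back up before each dequeue, so at least
            # min_after_dequeue elements remain while input lasts.
            while len(buf) < min_after_dequeue + 1:
                nxt = next(it, None)
                if nxt is None:
                    break
                buf.append(nxt)
            # Dequeue a random element from the buffer.
            batch.append(buf.pop(rng.randrange(len(buf))))
        batches.append(batch)
    return batches

batches = shuffle_batches(range(100), batch_size=10,
                          capacity=40, min_after_dequeue=20)
print(len(batches))  # 10 batches; every input row appears exactly once
```

Note that with `min_after_dequeue=20` and `capacity=40`, each dequeue only ever samples from a window of at most 40 rows, so rows that start 100K apart in the file can never land in the same batch — which is why a too-small buffer gives poor shuffling.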

asked Oct 31 '22 by sir_thursday

1 Answer

There is a little bit about the number of threads to use at

https://www.tensorflow.org/versions/r0.9/how_tos/reading_data/index.html#batching

Unfortunately, I don't think there is a simple answer for batch sizes. The most efficient batch size for a network depends on many details of the network. In practice, if you care about optimal performance, you're going to need to do a fair amount of trial and error (perhaps starting from the values used by a similar network).

answered Nov 15 '22 by Josh11b