Oscillating accuracy of CNN training with TensorFlow for MNIST handwritten digits

I'm following the tutorial "Deep MNIST for Experts", https://www.tensorflow.org/versions/r0.11/tutorials/mnist/pros/index.html#deep-mnist-for-experts

Using convolutional neural networks, I get an accuracy of 93.49%, which is in fact low, and I'm trying to improve it, but I have a doubt. According to the tutorial:

for i in range(20000):
    batch = mnist.train.next_batch(50)
    # Every 100 steps, log the accuracy on the current batch,
    # with dropout disabled (keep_prob = 1.0) for evaluation.
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    # Train on the batch with dropout enabled (keep_prob = 0.5).
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

The training accuracy is logged every 100 iterations, and it keeps oscillating: it increases and then decreases again.

step 100, training accuracy 0.1
step 200, training accuracy 0.13
step 300, training accuracy 0.12
step 400, training accuracy 0.08
step 500, training accuracy 0.12
step 600, training accuracy 0.05
step 700, training accuracy 0.09
step 800, training accuracy 0.1
step 900, training accuracy 0.12
step 1000, training accuracy 0.09
step 1100, training accuracy 0.11
step 1200, training accuracy 0.09
step 1300, training accuracy 0.11
step 1400, training accuracy 0.06
step 1500, training accuracy 0.09
step 1600, training accuracy 0.14
step 1700, training accuracy 0.07
step 1800, training accuracy 0.08
......
step 19800, training accuracy 0.14
step 19900, training accuracy 0.07

Is there any reason for that? Or is it normal? If so, why? Also, what kind of variables can I change to improve the final accuracy? I've already tried changing the learning rate.

asked Oct 20 '16 by Sera_Vinicit

1 Answer

Oscillating accuracy is typically caused by a learning_rate that is too high. My first tip would indeed be to lower the learning_rate. Did you test multiple learning rates on a logarithmic scale, e.g. 0.1, 0.05, 0.02, 0.01, 0.005, 0.002, ...?

Using drastically smaller learning rates should remove the oscillating accuracy. Also check this answer on Kaggle and the linked document to get a better understanding.
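
As a rough illustration, here is a minimal sketch of such a sweep against the r0.11 tutorial graph. build_model() is a hypothetical helper that rebuilds the tutorial's network and returns its placeholders and ops; everything else follows the tutorial's own API.

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Try learning rates spaced on a logarithmic scale. build_model() is a
# hypothetical helper (not from the tutorial) that rebuilds the graph
# and returns its placeholders and ops.
for lr in [0.1, 0.05, 0.02, 0.01, 0.005, 0.002]:
    with tf.Graph().as_default(), tf.Session() as sess:
        x, y_, keep_prob, cross_entropy, accuracy = build_model()
        train_step = tf.train.AdamOptimizer(lr).minimize(cross_entropy)
        sess.run(tf.initialize_all_variables())  # r0.11-era initializer
        for i in range(2000):  # a short run is enough to compare rates
            batch = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
        # Compare the candidates on the fixed validation set, dropout off.
        val_acc = accuracy.eval(feed_dict={x: mnist.validation.images,
                                           y_: mnist.validation.labels,
                                           keep_prob: 1.0})
        print("learning rate %g -> validation accuracy %g" % (lr, val_acc))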

EDIT:

Based on the remark in the comment: this accuracy is measured per batch. Since you are comparing the accuracies of different batches each time (an easy batch vs. a harder one), it's normal that you don't get a monotonic increase in accuracy. You can further reduce the oscillations:

  • Increasing the batch size should decrease the fluctuations, since the impact of the difficulty of individual examples is averaged out.

  • You could also calculate the training accuracy over a constant set of examples:

    • Using a validation set

    • Averaging the batch accuracies over all batches in one epoch

    • Actually calculating the accuracy over all examples in the training set after a fixed number of training steps. This of course has a big impact on the training time if you have a large training set (see the sketch after this list).
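
As a minimal sketch of the last two options: evaluate a fixed set of examples in equal-sized chunks and average the per-chunk accuracies. It assumes the tutorial's accuracy, x, y_ and keep_prob ops, the mnist dataset object, and an active session (the tutorial's InteractiveSession); dataset_accuracy is a hypothetical helper, not part of the tutorial.

# Hypothetical helper: average the per-batch accuracies over a whole
# fixed dataset (e.g. mnist.validation or mnist.train), dropout disabled.
# Assumes the tutorial's InteractiveSession is active so eval() works.
def dataset_accuracy(dataset, batch_size=1000):
    num_batches = dataset.num_examples // batch_size
    total = 0.0
    for _ in range(num_batches):
        batch = dataset.next_batch(batch_size)
        total += accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
    return total / num_batches

# In the training loop, replace the per-batch log line with e.g.:
# print("step %d, validation accuracy %g" % (i, dataset_accuracy(mnist.validation)))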

answered Sep 28 '22 by Fematich