 

How to overcome overfitting in CNN - standard methods don't work

I've recently been playing around with the Stanford cars dataset (http://ai.stanford.edu/~jkrause/cars/car_dataset.html). From the very beginning I had an overfitting problem, so I decided to:

  1. Add regularization (L2, dropout, batch norm, ...)
  2. Try different architectures (VGG16, VGG19, InceptionV3, DenseNet121, ...)
  3. Try transfer learning using models trained on ImageNet
  4. Use data augmentation (a rough sketch of my setup is below)
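
Roughly what my current setup looks like (a minimal sketch only; the paths, image size and hyperparameters below are illustrative, not my exact values):

```python
# Minimal sketch: transfer learning from ImageNet with augmentation,
# dropout and L2 regularization (paths and hyperparameters are illustrative).
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2

NUM_CLASSES = 196  # Stanford Cars has 196 classes

base = InceptionV3(weights="imagenet", include_top=False, input_shape=(299, 299, 3))
base.trainable = False  # freeze the pretrained backbone first

x = GlobalAveragePooling2D()(base.output)
x = Dropout(0.5)(x)
out = Dense(NUM_CLASSES, activation="softmax", kernel_regularizer=l2(1e-4))(x)
model = Model(base.input, out)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Data augmentation on the training set only
train_gen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
).flow_from_directory("data/train", target_size=(299, 299), batch_size=32)

val_gen = ImageDataGenerator(rescale=1.0 / 255).flow_from_directory(
    "data/val", target_size=(299, 299), batch_size=32
)

model.fit(train_gen, validation_data=val_gen, epochs=30)
```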

Every step moved me a little bit forward. However, I ended up with 50% validation accuracy (having started below 20%) compared to 99% training accuracy.

Do you have any ideas about what more I can do to get to around 80-90% accuracy?

Hope this can help some people! :)

asked Feb 04 '18 by Michał Gdak


1 Answer

Things you should try include:

  • Early stopping, i.e. use a portion of your data to monitor validation loss and stop training if performance does not improve for a number of epochs (see the sketch after this list).
  • Check whether you have unbalanced classes; if so, use class weighting so that each class is equally represented during training.
  • Regularization parameter tuning: try different L2 coefficients, different dropout values, and different regularization constraints (e.g. L1).
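
For the first two points, a minimal sketch in Keras (assuming your model and data generators already exist; `model`, `train_gen` and `val_gen` below are placeholders for them, and the numbers are illustrative):

```python
# Sketch: early stopping on validation loss plus class weighting.
# `model`, `train_gen` and `val_gen` stand in for your existing compiled
# model and flow_from_directory generators.
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor="val_loss",          # watch validation loss
    patience=5,                  # stop after 5 epochs without improvement
    restore_best_weights=True,   # roll back to the best epoch seen
)

# Weight each class inversely to its frequency in the training set.
train_labels = train_gen.classes  # integer label per training image
weights = compute_class_weight(
    class_weight="balanced", classes=np.unique(train_labels), y=train_labels
)
class_weights = dict(enumerate(weights))

model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=100,                  # upper bound; early stopping ends training sooner
    callbacks=[early_stop],
    class_weight=class_weights,
)
```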

Another general suggestion is to try to replicate the state-of-the-art models on this particular dataset and see whether they perform as they should.
Also make sure to have all implementation details ironed out (e.g. that convolution is performed along the width and height dimensions, and not along the channels dimension - this is a classic rookie mistake when starting out with Keras, for instance).
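
A quick sanity check for that last point, assuming a TensorFlow backend (shapes here are just an example):

```python
# Quick sanity check that convolutions run over height/width, not channels.
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.models import Model

print(K.image_data_format())  # should be "channels_last" for (height, width, channels) inputs

# A 3x3 convolution with "same" padding over a 224x224 RGB image should keep
# the spatial dimensions and change only the channel count.
inp = Input(shape=(224, 224, 3))
out = Conv2D(16, (3, 3), padding="same")(inp)
print(Model(inp, out).output_shape)  # expected: (None, 224, 224, 16)
```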

It would also help to have some more details on the code that you are using, but for now these suggestions will do.
50% accuracy on a 200-class problem doesn't sound so bad anyway.

Cheers

answered Sep 24 '22 by Daniele Grattarola