Training GAN on small dataset of images

I have created a DCGAN and already trained it on the CIFAR-10 dataset. Now I would like to train it on a custom dataset.

I have already gathered around 1200 images; it is practically impossible to gather more. What should I do?

Stefan Radonjic asked Feb 03 '18 21:02


1 Answer

We are going to post a paper in the coming weeks about stochastic deconvolutions for the generator, which can improve stability and variety for exactly this kind of problem. If you are interested, I can send you the current version of the paper right now. But generally speaking, the idea is simple:

  1. Build a classic GAN.
  2. For the deep layers of the generator (say, half of them) use stochastic deconvolutions (sdeconv).
  3. An sdeconv is just a normal deconv layer, but its filters are selected on the fly, randomly, from a bank of filters. So your filter bank shape can be, for instance, (16, 128, 3, 3), where 16 is the number of filter sets, 128 the number of filters in each set, and 3x3 the filter size. Your selection of a filter set at each training step is [random uniform 0-16, :, :, :]; unselected filters remain untrained for that step. In TensorFlow you also want to select different filter sets for different images within a batch, because TF keeps updating variables even when they are not used (we believe this is a bug: TF applies the last known gradients to all variables even if they are not part of the current dynamic sub-graph, so you have to utilize as many variables as you can). A minimal sketch of such a layer follows this list.
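For concreteness, here is a minimal TensorFlow 2 / Keras sketch of such a layer. It is an illustration of the idea, not the authors' implementation: the SDeconv name and all sizes are made up, the bank adds an input-channel axis to the (16, 128, 3, 3) layout above, and for brevity it draws one filter set per training step rather than per image.

    import tensorflow as tf

    class SDeconv(tf.keras.layers.Layer):
        """Stochastic deconvolution: keep a bank of deconv filter sets
        and draw one set uniformly at random on every forward pass."""

        def __init__(self, num_sets, out_channels, kernel_size=3, stride=2, **kwargs):
            super().__init__(**kwargs)
            self.num_sets = num_sets
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride

        def build(self, input_shape):
            in_ch = int(input_shape[-1])
            # Bank of shape (num_sets, kh, kw, out_ch, in_ch), e.g. 16 sets
            # of 128 3x3 filters as in the (16, 128, 3, 3) example above.
            self.bank = self.add_weight(
                name="filter_bank",
                shape=(self.num_sets, self.kernel_size, self.kernel_size,
                       self.out_channels, in_ch),
                initializer="glorot_uniform",
                trainable=True)

        def call(self, x):
            # "[random uniform 0-16, :, :, :]": pick one filter set per step.
            # Per-image selection within the batch, as recommended above,
            # would instead draw a different index for every example.
            idx = tf.random.uniform([], 0, self.num_sets, dtype=tf.int32)
            kernel = self.bank[idx]
            n = tf.shape(x)[0]
            h = tf.shape(x)[1] * self.stride
            w = tf.shape(x)[2] * self.stride
            return tf.nn.conv2d_transpose(
                x, kernel,
                output_shape=tf.stack([n, h, w, self.out_channels]),
                strides=[1, self.stride, self.stride, 1],
                padding="SAME")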

That's it. With 3 sdeconv layers of 16 sets in each bank, you effectively have 16x16x16 = 4096 combinations of different internal routes to produce an output (a toy stack illustrating this follows the list below). How does this help on a small dataset? Small datasets usually have relatively large "topic" variance while the dataset as a whole is of one nature (photos of cats: all are realistic photos, but of different types of cats). On such datasets a GAN collapses very quickly; with sdeconv, however:

  1. The upper, normal deconv layers learn how to reconstruct the style "realistic photo".
  2. The lower sdeconv layers learn the sub-distributions: "dark cat", "white cat", "red cat", and so on.
  3. The model can be seen as an ensemble of weak generators; each sub-generator is weak and can collapse, but it will be "supported" by another sub-generator that temporarily outperforms the discriminator.
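As a toy illustration of the 16x16x16 = 4096 routing combinations, three of the sketched SDeconv layers can be stacked at the bottom of a generator (all sizes here are arbitrary):

    # Three stochastic layers of 16 sets each: 16 * 16 * 16 = 4096
    # distinct internal routes through the generator per sample.
    x = tf.random.normal([8, 4, 4, 256])     # toy batch of feature maps
    x = tf.nn.relu(SDeconv(16, 128)(x))      # 4x4   -> 8x8
    x = tf.nn.relu(SDeconv(16, 64)(x))       # 8x8   -> 16x16
    x = tf.tanh(SDeconv(16, 3)(x))           # 16x16 -> 32x32 RGB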

MNIST is a great example of such a dataset: high "topic" variance, but the same style of digits.

GAN + weight norm + PReLU: collapsed after 1000 steps, died after 2000, and could only describe one "topic".

GAN + weight norm + PReLU + sdeconv, 4388 steps: some local variety degradation within sub-topics is visible, but there is no global collapse and global visual variety is preserved.

azrev answered Oct 13 '22 00:10