
When training GANs in Keras, are multiple passes required to optimize the generator and discriminator?

I'm more familiar with TensorFlow graph training than with Keras, but I'm trying out Keras here.

In building a GAN, the generator needs to be optimized against a different loss than the discriminator (the opposite loss). In base TensorFlow this is easy enough to implement, either with two optimizers or by calling optimizer.compute_gradients(...) and optimizer.apply_gradients(...) separately with the appropriate group of weights.
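As a minimal sketch of the base-TensorFlow approach described above: one graph, one forward pass, two losses, and two optimizers that each call compute_gradients/apply_gradients on their own var_list. It uses tf.compat.v1 so the graph-mode code also runs under TF 2.x. The tiny dense networks, the scope names "gen"/"disc", and the non-saturating generator loss (instead of literally negating the discriminator loss) are illustrative assumptions, not taken from the question.

```python
import numpy as np
import tensorflow as tf

# Assumption: tiny hypothetical networks; tf.compat.v1 keeps TF 1.x graph semantics.
graph = tf.Graph()
with graph.as_default():
    z = tf.compat.v1.placeholder(tf.float32, [None, 4], name="z")
    x_real = tf.compat.v1.placeholder(tf.float32, [None, 2], name="x_real")

    def dense(inputs, in_dim, units, scope, activation=None):
        """Fully connected layer whose variables live under `scope`."""
        with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
            w = tf.compat.v1.get_variable("w", [in_dim, units])
            b = tf.compat.v1.get_variable(
                "b", [units], initializer=tf.compat.v1.zeros_initializer())
        out = tf.matmul(inputs, w) + b
        return activation(out) if activation is not None else out

    def generator(noise):
        with tf.compat.v1.variable_scope("gen", reuse=tf.compat.v1.AUTO_REUSE):
            return dense(dense(noise, 4, 8, "h", tf.nn.tanh), 8, 2, "out")

    def discriminator(x):
        # AUTO_REUSE: the real and fake branches share the same weights.
        with tf.compat.v1.variable_scope("disc", reuse=tf.compat.v1.AUTO_REUSE):
            return dense(dense(x, 2, 8, "h", tf.nn.tanh), 8, 1, "out")  # logits

    fake = generator(z)
    d_real = discriminator(x_real)
    d_fake = discriminator(fake)

    xent = tf.nn.sigmoid_cross_entropy_with_logits
    d_loss = (tf.reduce_mean(xent(labels=tf.ones_like(d_real), logits=d_real))
              + tf.reduce_mean(xent(labels=tf.zeros_like(d_fake), logits=d_fake)))
    # Non-saturating generator loss: push D to label the fakes as real.
    g_loss = tf.reduce_mean(xent(labels=tf.ones_like(d_fake), logits=d_fake))

    trainable = tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES
    g_vars = tf.compat.v1.get_collection(trainable, scope="gen")
    d_vars = tf.compat.v1.get_collection(trainable, scope="disc")

    # Two optimizers, each restricted to its own group of weights.
    d_opt = tf.compat.v1.train.AdamOptimizer(1e-3)
    g_opt = tf.compat.v1.train.AdamOptimizer(1e-3)
    d_grads = d_opt.compute_gradients(d_loss, var_list=d_vars)
    g_grads = g_opt.compute_gradients(g_loss, var_list=g_vars)
    train_op = tf.group(d_opt.apply_gradients(d_grads),
                        g_opt.apply_gradients(g_grads))
    init = tf.compat.v1.global_variables_initializer()

rng = np.random.default_rng(0)
feed = {z: rng.normal(size=(16, 4)), x_real: rng.normal(loc=2.0, size=(16, 2))}
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    g_before, d_before = sess.run([g_vars, d_vars])
    # A single run: one forward pass feeds both losses and both updates.
    _, d_val, g_val = sess.run([train_op, d_loss, g_loss], feed_dict=feed)
    g_after, d_after = sess.run([g_vars, d_vars])
```

Because both gradient sets are built from the same forward tensors, one sess.run of train_op updates the generator and the discriminator together, each with its own loss.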

In Keras, I don't see how to achieve either of these. In implementations such as Keras-GAN, the generator and discriminator appear to be split into separate models and then trained independently, batch by batch. This means many more passes are required per effective update than in the base TensorFlow implementation, where two optimizers operate on a single pass.

Is there a way to implement the optimizer for GANs so that both generator and discriminator get trained in one pass in Keras?

TF 1.14

David Parks asked Jul 09 '19 at 23:07



1 Answer

This is a really tough question for Keras for several reasons:

  1. A model can only have one optimizer; supporting two or more would require changing the Keras source code.

  2. Even with a custom optimizer, you could separate the weights, but there is no support for separating the losses, as the source code for the optimizers shows. Most likely the optimizer already receives a single combined loss, which makes it impossible to assign one loss to one group of weights and a different loss to the other group.

  3. The training mechanisms are hard to find in the code. They are spread all around, supporting many features such as loss weights, sample weights, etc. Working through all of it and then deciding what to change would take too much time.

Suggested approach

Build your model in Keras as you normally would: the discriminator, the generator, their connections and outputs.

Just don't compile it. Instead, keep track of the main tensors (generator output, discriminator output, generator input), write the loss functions TensorFlow-style, and train everything TensorFlow-style, with a separate optimizer for each model's weights.
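A minimal sketch of this suggestion, assuming TensorFlow 2.x with eager execution (under the question's TF 1.14 graph mode, one would instead pass each model's trainable_weights as var_list to a tf.train optimizer). The layer sizes, the non-saturating generator loss, and the train_step helper are hypothetical choices for illustration. The Keras models are built but never compiled; one persistent GradientTape records a single forward pass, and two optimizers each update only their own model's weights.

```python
import numpy as np
import tensorflow as tf

LATENT_DIM, DATA_DIM = 4, 2  # hypothetical sizes

# Build both halves with Keras as usual, but never call compile().
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(LATENT_DIM,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(DATA_DIM),
])
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(DATA_DIM,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(1),  # raw logit
])

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
d_opt = tf.keras.optimizers.Adam(1e-3)
g_opt = tf.keras.optimizers.Adam(1e-3)


def train_step(x_real, z):
    """One forward pass, two losses, two optimizers (hypothetical helper)."""
    # persistent=True lets the same tape produce two sets of gradients.
    with tf.GradientTape(persistent=True) as tape:
        fake = generator(z, training=True)
        d_real = discriminator(x_real, training=True)
        d_fake = discriminator(fake, training=True)
        d_loss = bce(tf.ones_like(d_real), d_real) + bce(tf.zeros_like(d_fake), d_fake)
        # Non-saturating generator loss: push D to label the fakes as real.
        g_loss = bce(tf.ones_like(d_fake), d_fake)
    d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    del tape  # release the resources held by the persistent tape
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    return d_loss, g_loss


rng = np.random.default_rng(0)
x_real = rng.normal(loc=2.0, size=(16, DATA_DIM)).astype("float32")
z = rng.normal(size=(16, LATENT_DIM)).astype("float32")
g_before, d_before = generator.get_weights(), discriminator.get_weights()
d_loss, g_loss = train_step(x_real, z)
```

Wrapping train_step in @tf.function would trace it into a graph for speed; it is left eager here for readability.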

Daniel Möller answered Nov 15 '22 at 18:11
