How to train the generator in a GAN?

After reading GAN tutorials and code samples I still don't understand how the generator is trained. Let's take a simple case:

- the generator's input is noise and its output is a 10x10 grayscale image
- the discriminator's input is a 10x10 image and its output is a single value from 0 to 1 (fake or real)

Training the discriminator is easy: take its output for a real image and expect 1; take its output for a fake and expect 0. Here we're working with the actual output size, a single value.

But training the generator is different: we take the fake output (one value) and set the expected output to 1. That sounds more like training the discriminator again. The generator's output is a 10x10 image, so how can we train it with only a single value? How does backpropagation work in this case?

user1628106 asked Sep 14 '25 11:09
1 Answer

To train the generator, you backpropagate through the entire combined model while freezing the discriminator's weights, so that only the generator is updated.

For this, we compute d(g(z; θg); θd), where θg and θd are the weights of the generator and discriminator. To update the generator, we compute the gradient with respect to θg only, ∂loss(d(g(z; θg); θd)) / ∂θg, and then update θg using ordinary gradient descent.
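To see why a single scalar loss is enough, here is a framework-free numpy sketch with toy linear models (all names and sizes such as `Wg`, `Wd`, and the 4-dim noise are invented for illustration). The key point: the scalar loss backpropagates through d into a full gradient over every pixel of the generated image, and from there into θg.

```python
import numpy as np

# Toy 1-layer "generator" g(z) = z @ Wg and logistic
# "discriminator" d(x) = sigmoid(x @ Wd). All names are made up.
rng = np.random.default_rng(0)
Wg = rng.normal(size=(4, 100))   # noise dim 4 -> flattened 10x10 image
Wd = rng.normal(size=(100, 1))   # image -> single real/fake score

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

z = rng.normal(size=(8, 4))      # batch of noise
x_fake = z @ Wg                  # generator forward: 8 "images"
p = sigmoid(x_fake @ Wd)         # discriminator forward: 8 scalars

# Generator loss: binary cross-entropy against target 1 ("fool d")
loss = -np.mean(np.log(p + 1e-8))

# Backprop: the single scalar flows back through d into every pixel
# of x_fake (a full 100-dim gradient per sample), then into Wg.
dp = -1.0 / (p + 1e-8) / p.size          # dloss/dp
da = dp * p * (1 - p)                    # through the sigmoid
dx_fake = da @ Wd.T                      # gradient w.r.t. each "pixel"
dWg = z.T @ dx_fake                      # gradient w.r.t. generator weights

# Only Wg is updated; Wd stays frozen.
Wg -= 0.1 * dWg
```

So even though the discriminator emits one number, the chain rule spreads its gradient over all 100 pixels and then over all generator weights.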

In Keras, this might look something like this (using the functional API):

from keras.layers import Input
from keras.models import Model

genInput = Input(input_shape)   # shape of the noise vector
discriminator = ...             # maps 10x10 image -> single score
generator = ...                 # maps noise -> 10x10 image

# Compile the discriminator as a trainable standalone model
discriminator.trainable = True
discriminator.compile(...)

# Freeze the discriminator *before* compiling the combined model
discriminator.trainable = False
combined = Model(genInput, discriminator(generator(genInput)))
combined.compile(...)

Setting trainable to False does not affect models that are already compiled; only models compiled afterwards are frozen. The discriminator is therefore trainable as a standalone model but frozen inside the combined model.

Then, to train your GAN:

import numpy as np

batch_size = ...
one_out = np.ones((batch_size, 1))    # labels for "real"
zero_out = np.zeros((batch_size, 1))  # labels for "fake"

X_real = ...
noise = ...
X_gen = generator.predict(noise)

# This will only train the discriminator
loss_real = discriminator.train_on_batch(X_real, one_out)
loss_fake = discriminator.train_on_batch(X_gen, zero_out)

d_loss = 0.5 * np.add(loss_real, loss_fake)

noise = ...
# This will only train the generator (the discriminator is frozen in `combined`)
g_loss = combined.train_on_batch(noise, one_out)
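The Keras calls above hide the mechanics. As a self-contained sanity check, the same alternating loop can be written in plain numpy with tiny linear models and hand-written gradients (all names and sizes here are invented for illustration):

```python
import numpy as np

# Framework-free sketch of the alternating GAN loop: a linear
# generator, a logistic discriminator, hand-written gradients.
rng = np.random.default_rng(1)
Wg = rng.normal(scale=0.1, size=(4, 100))  # noise -> flat 10x10 image
Wd = rng.normal(scale=0.1, size=(100, 1))  # image -> real/fake score
lr = 0.05

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def bce_grad(p, target):
    # d/dp of binary cross-entropy, averaged over the batch
    return (p - target) / (p * (1 - p) + 1e-8) / p.size

X_real = rng.normal(loc=1.0, size=(16, 100))  # stand-in "real" images

for step in range(200):
    # --- Train the discriminator (generated batch treated as constant) ---
    z = rng.normal(size=(16, 4))
    X_fake = z @ Wg                       # like generator.predict(noise)
    for X, target in ((X_real, 1.0), (X_fake, 0.0)):
        p = sigmoid(X @ Wd)
        da = bce_grad(p, target) * p * (1 - p)
        Wd -= lr * (X.T @ da)             # only Wd moves here

    # --- Train the generator through the frozen discriminator ---
    z = rng.normal(size=(16, 4))
    X_fake = z @ Wg
    p = sigmoid(X_fake @ Wd)
    da = bce_grad(p, 1.0) * p * (1 - p)   # target 1: "fool d"
    dX_fake = da @ Wd.T                   # gradient reaches every pixel
    Wg -= lr * (z.T @ dX_fake)            # only Wg moves; Wd stays frozen
```

The asymmetry is exactly the one the Keras code encodes: the discriminator step updates only Wd, and the generator step backpropagates through the (frozen) discriminator to update only Wg.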
Palle answered Sep 17 '25 20:09