I was going through this example - https://github.com/pytorch/examples/blob/master/dcgan/main.py and I have a basic question.
fake = netG(noise)
label = Variable(label.fill_(fake_label))
output = netD(fake.detach()) # detach to avoid training G on these labels
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()
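For context, errD_real in the snippet comes from the real-batch half of the same discriminator step. A minimal self-contained sketch of the full D update (tiny stand-in networks for illustration, not the repo's DCGAN models) looks like:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-ins for the DCGAN networks (shapes are illustrative only).
netG = nn.Linear(4, 4)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.01)

real = torch.randn(2, 4)
noise = torch.randn(2, 4)
real_label = torch.ones(2, 1)
fake_label = torch.zeros(2, 1)

netD.zero_grad()

# Real half of the discriminator step.
errD_real = criterion(netD(real), real_label)
errD_real.backward()

# Fake half: detach so no gradients are built up for netG.
fake = netG(noise)
output = netD(fake.detach())
errD_fake = criterion(output, fake_label)
errD_fake.backward()
D_G_z1 = output.mean().item()    # D's average score on the fakes

errD = errD_real + errD_fake
optimizerD.step()

print(netG.weight.grad is None)  # True: detach kept the G graph out of it
```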
I understand why we call detach() on the variable fake: so that no gradients are computed for the Generator parameters. My question is, does it matter, since optimizerD.step() is going to update the parameters of the Discriminator only?
optimizerD is defined as:
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
Besides, in the next step, when we update the parameters of the Generator, we first call netG.zero_grad(), which removes all previously computed gradients. Moreover, when we update the parameters of the G network, we do output = netD(fake). Here we are not using detach. Why?
So why is detaching the variable (line 3) necessary in the above code?
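To check the premise of the question, here is a toy sketch (stand-in networks, not the repo's models) showing that without detach the backward pass does populate gradients on netG, yet optimizerD.step() leaves netG's parameters untouched:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-ins for the DCGAN networks (shapes are illustrative only).
netG = nn.Linear(4, 4)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.01)

noise = torch.randn(2, 4)
fake = netG(noise)                  # no detach here, on purpose
label = torch.zeros(2, 1)           # fake_label = 0

errD_fake = criterion(netD(fake), label)
errD_fake.backward()                # populates grads on BOTH netD and netG

g_weight_before = netG.weight.clone()
optimizerD.step()                   # only netD's parameters move

print(netG.weight.grad is not None)               # True: wasted gradient work
print(torch.equal(netG.weight, g_weight_before))  # True: netG unchanged
```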
ORIGINAL ANSWER (WRONG / INCOMPLETE)
You're right: optimizerD only updates netD, and the gradients on netG are not used before netG.zero_grad() is called, so detaching is not necessary; it just saves time, because you're not computing gradients for the generator.
You're basically also answering your other question yourself: you don't detach fake in the second block because you specifically want to compute gradients on netG in order to update its parameters.
Note how in the second block real_label is used as the corresponding label for fake, so if the discriminator finds the fake input to be real, the final loss is small, and vice versa, which is precisely what you want for the generator. Not sure if that's what confused you, but it's really the only difference compared to training the discriminator on fake inputs.
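That generator step can be sketched like this (tiny stand-in networks, not the repo's DCGAN models):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-ins for the DCGAN networks (shapes are illustrative only).
netG = nn.Linear(4, 4)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=0.01)

noise = torch.randn(2, 4)
fake = netG(noise)

# Generator update: no detach, and the target is real_label (1),
# so the loss is small when D is fooled into calling fakes real.
netG.zero_grad()
real_label = torch.ones(2, 1)
errG = criterion(netD(fake), real_label)  # note: fake, NOT fake.detach()
errG.backward()                           # gradients flow back into netG
g_weight_before = netG.weight.clone()
optimizerG.step()

print(netG.weight.grad is not None)       # True: G received gradients
```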
EDIT
Please see FatPanda's comment! My original answer is in fact incorrect. PyTorch destroys (parts of) the compute graph when .backward() is called. Without detaching before errD_fake.backward(), the later errG.backward() call would not be able to backprop into the generator, because the required graph is no longer available (unless you specify retain_graph=True). I'm relieved Soumith made the same mistake :D
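This failure can be reproduced with a toy example (stand-in networks; the exact error message varies across PyTorch versions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins for the DCGAN networks (shapes are illustrative only).
netG = nn.Linear(4, 4)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()

noise = torch.randn(2, 4)
fake = netG(noise)

# D step WITHOUT detach: backward() frees the noise -> fake part of the graph.
errD_fake = criterion(netD(fake), torch.zeros(2, 1))
errD_fake.backward()

# G step: backprop must traverse the already-freed generator graph.
errG = criterion(netD(fake), torch.ones(2, 1))
try:
    errG.backward()
except RuntimeError:
    print("RuntimeError: generator graph already freed")  # this branch runs

# With fake.detach() in the D step, the generator graph survives:
fake2 = netG(noise)
criterion(netD(fake2.detach()), torch.zeros(2, 1)).backward()
criterion(netD(fake2), torch.ones(2, 1)).backward()  # works fine
```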
The top voted answer is INCORRECT/INCOMPLETE!
Check this: https://github.com/pytorch/examples/issues/116, and have a look at @plopd's answer:
This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G when we actually update the generator. If we do not detach, then, although fake is not needed for the gradient update of D, it will still be added to the computational graph, and as a consequence of the backward pass, which clears all the variables in the graph (retain_graph=False by default), fake won't be available when G is updated.
This post also clarifies a lot: https://zhuanlan.zhihu.com/p/43843694 (In Chinese).