Why does detach need to be called on the variable in this example?

Tags:

pytorch

I was going through this example - https://github.com/pytorch/examples/blob/master/dcgan/main.py - and I have a basic question about the following snippet:

fake = netG(noise)
label = Variable(label.fill_(fake_label))
output = netD(fake.detach()) # detach to avoid training G on these labels
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()

I understand why we call detach() on the variable fake: so that no gradients are computed for the generator's parameters. My question is, does it matter, since optimizerD.step() is going to update only the parameters associated with the discriminator anyway?

optimizerD is defined as: optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

Besides, in the next step, when we update the parameters of the generator, we first call netG.zero_grad(), which removes all previously computed gradients anyway. Moreover, when we update the parameters of the G network, we do this: output = netD(fake). Here we are not using detach. Why?

So why is detaching the variable (line 3 of the snippet) necessary in the above code?

asked Oct 26 '17 by Wasi Ahmad

2 Answers

ORIGINAL ANSWER (WRONG / INCOMPLETE)

You're right: optimizerD only updates netD, and the gradients on netG are not used before netG.zero_grad() is called, so detaching is not necessary; it just saves time, because you're not computing gradients for the generator.

You're basically answering your other question yourself: you don't detach fake in the second block because you specifically want to compute gradients with respect to netG so that you can update its parameters.

Note how in the second block real_label is used as the target label for fake: if the discriminator finds the fake input to be real, the final loss is small, and vice versa, which is precisely what you want for the generator. Not sure if that's what confused you, but it's really the only difference compared to training the discriminator on fake inputs.
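For reference, the generator update you are asking about looks roughly like this (paraphrasing the linked example in the same style as your snippet; optimizerG is assumed to be the Adam optimizer over netG.parameters(), analogous to optimizerD):

netG.zero_grad()
label = Variable(label.fill_(real_label))  # note: real_label, not fake_label
output = netD(fake)                        # no detach: gradients must flow back into netG
errG = criterion(output, label)
errG.backward()
optimizerG.step()                          # assumed: optim.Adam(netG.parameters(), ...)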

EDIT

Please see FatPanda's comment! My original answer is in fact incorrect. PyTorch frees (parts of) the computation graph when .backward() is called. Without detaching fake before errD_fake.backward(), the later errG.backward() call would not be able to backprop into the generator, because the required graph is no longer available (unless you specify retain_graph=True). I'm relieved Soumith made the same mistake :D
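Here is a minimal, self-contained sketch of that failure mode, using tiny Linear layers as stand-ins for netG and netD and the current tensor API instead of Variable; it only illustrates the graph-freeing behaviour and is not the example's code:

import torch
import torch.nn as nn

netG = nn.Linear(3, 3)   # stand-in for the generator
netD = nn.Linear(3, 1)   # stand-in for the discriminator

noise = torch.randn(4, 3)
fake = netG(noise)

# Discriminator step WITHOUT detach: backward() also walks through netG's part
# of the graph and frees its saved buffers (retain_graph defaults to False).
errD_fake = netD(fake).mean()
errD_fake.backward()

# The generator step reuses the same `fake`, so backprop has to go through the
# already-freed part of the graph and PyTorch raises
# "Trying to backward through the graph a second time ...".
errG = netD(fake).mean()
try:
    errG.backward()
except RuntimeError as e:
    print(e)

# Using netD(fake.detach()) in the discriminator step leaves netG's graph
# untouched, and errG.backward() then succeeds.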

answered Oct 22 '22 by Jens Petersen


The top-voted answer is INCORRECT/INCOMPLETE!

Check this: https://github.com/pytorch/examples/issues/116, and have a look at @plopd's answer:

This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G when we actually update the generator. If we do not detach, then, although fake is not needed for the gradient update of D, it will still be added to the computational graph, and, as a consequence of the backward pass (which clears all the intermediate variables in the graph; retain_graph=False by default), fake won't be available when G is updated.
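To make that concrete, here is a minimal sketch of the ordering with detach, again using small Linear stand-ins; BCEWithLogitsLoss and the optimizers here are assumptions for the sketch, not the exact code from the example:

import torch
import torch.nn as nn

netG, netD = nn.Linear(3, 3), nn.Linear(3, 1)     # stand-ins, not the real DCGAN nets
criterion = nn.BCEWithLogitsLoss()                 # the example uses Sigmoid + BCELoss
optimizerD = torch.optim.Adam(netD.parameters())
optimizerG = torch.optim.Adam(netG.parameters())

noise = torch.randn(4, 3)
fake = netG(noise)                                 # graph: noise -> netG -> fake

# (1) Discriminator step: detach, so backward() only frees netD's sub-graph.
netD.zero_grad()
errD_fake = criterion(netD(fake.detach()), torch.zeros(4, 1))
errD_fake.backward()                               # the noise -> fake graph stays intact
optimizerD.step()

# (2) Generator step: fake is still usable because its graph was never freed.
netG.zero_grad()
errG = criterion(netD(fake), torch.ones(4, 1))     # fake labelled as "real"
errG.backward()                                    # backprops through netD into netG
optimizerG.step()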

This post also clarifies a lot: https://zhuanlan.zhihu.com/p/43843694 (in Chinese).

answered Oct 22 '22 by fatpanda2049