I am going through an open-source implementation of a domain-adversarial model (GAN-like). The implementation uses PyTorch, and I am not sure they use zero_grad() correctly. They call zero_grad() for the encoder optimizer (a.k.a. the generator) before the discriminator update. However, zero_grad() is hardly documented, and I couldn't find information about it.
Here is pseudo code comparing a standard GAN training loop (option 1) with their implementation (option 2). I think the second option is wrong, because it may accumulate the D_loss gradients in the parameters that E_opt updates. Can someone tell me whether these two pieces of code are equivalent?
Option 1 (a standard GAN implementation):
X, y = get_D_batch()
D_opt.zero_grad()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()
X, y = get_E_batch()
E_opt.zero_grad()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
Option 2 (calling zero_grad() for both optimizers at the beginning):
E_opt.zero_grad()
D_opt.zero_grad()
X, y = get_D_batch()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()
X, y = get_E_batch()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
It depends on the params argument of the torch.optim.Optimizer subclasses (e.g. torch.optim.SGD) and on the exact structure of the model.
Assuming E_opt and D_opt hold disjoint sets of parameters (model.encoder and model.decoder do not share weights), something like this:
E_opt = torch.optim.Adam(model.encoder.parameters())
D_opt = torch.optim.Adam(model.decoder.parameters())
both options MIGHT indeed be equivalent (see the commentary on your source code; additionally, I have added backward(), which is really important here, and also changed model to discriminator and generator where appropriate, as I assume that is the case):
# Starting with zero gradients
E_opt.zero_grad()
D_opt.zero_grad()
# See the commentary below for the possible cases
X, y = get_D_batch()
pred = discriminator(X)
D_loss = loss(pred, y)
# This accumulates gradients in the discriminator only,
# OR in both the discriminator and the generator,
# depending on other parts of the code (see the commentary below)
D_loss.backward()
# Update the discriminator's weights
D_opt.step()
# This relies only on random noise input, so the discriminator
# is not part of this equation
X, y = get_E_batch()
pred = generator(X)
E_loss = loss(pred, y)
E_loss.backward()
# Only the generator's parameters are updated here, always
E_opt.step()
Now it all comes down to get_D_batch feeding data to the discriminator.
If it passes real samples, this is not a problem: the generator is not involved, and only the discriminator takes part in this operation.
If, on the other hand, get_D_batch simply calls X = generator(noise) and passes that data to the discriminator, gradient accumulation may indeed occur. In that case both the discriminator and the generator have their gradients accumulated during backward(), as both are used.
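Here is a minimal sketch of that failure mode, with toy linear modules standing in for the real networks (which are not part of the original code):
import torch
import torch.nn as nn

# Toy stand-ins for the real networks
generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)

noise = torch.randn(8, 4)
fake = generator(noise)        # the generator becomes part of the graph...
pred = discriminator(fake)     # ...together with the discriminator
D_loss = pred.mean()
D_loss.backward()

print(discriminator.weight.grad is not None)  # True, as expected
print(generator.weight.grad is not None)      # True, the unwanted accumulation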
We should take the generator out of the equation. Taken from the PyTorch DCGAN example, there is a little line like this:
# Generate fake image batch with G
fake = generator(noise)
label.fill_(fake_label)
# DETACH HERE
output = discriminator(fake.detach()).view(-1)
What detach() does is "stop" the gradient by detaching the tensor from the computational graph, so no gradients will be backpropagated along this variable. This leaves the gradients of the generator untouched, so no accumulation happens.
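To make that concrete, here is a small sketch (again with toy modules of my own, not the original networks) showing that after detach() the generator's parameters keep grad == None:
import torch
import torch.nn as nn

generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)

fake = generator(torch.randn(8, 4))
# detach() cuts the graph between the generator and the discriminator
pred = discriminator(fake.detach())
pred.mean().backward()

print(discriminator.weight.grad is not None)  # True
print(generator.weight.grad is None)          # True: nothing accumulated in the generator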
Another way (IMO better) would be to use a with torch.no_grad(): block, like this:
# Generate fake image batch with G
with torch.no_grad():
    fake = generator(noise)
label.fill_(fake_label)
# NO DETACH NEEDED
output = discriminator(fake).view(-1)
This way the generator operations will not build part of the graph at all, so we get better performance (in the first case the graph would be built and then detached afterwards).
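A quick way to see the difference (again a sketch with toy modules, not the original code): under no_grad() the generator's output does not even require grad, so no graph is recorded in the first place:
import torch
import torch.nn as nn

generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)

with torch.no_grad():
    fake = generator(torch.randn(8, 4))

print(fake.requires_grad)                 # False: no graph was recorded for the generator pass
discriminator(fake).mean().backward()     # only the discriminator accumulates gradients
print(generator.weight.grad is None)      # True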
Yeah, all in all the first option is better for standard GANs, as one does not have to think about such stuff (the people implementing it should, but readers should not). Though there are also other approaches, like a single optimizer for both the generator and the discriminator (one cannot zero_grad() only a subset of the parameters, e.g. the encoder, in this case), weight sharing, and others, which further clutter the picture.
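For completeness, the single-optimizer variant mentioned above could look roughly like this (a sketch with toy modules; itertools.chain is just one way to combine the parameter iterators, not necessarily what the implementation you are reading does):
import itertools
import torch
import torch.nn as nn

# Toy stand-ins for the real networks
generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)

# One optimizer over both parameter sets
opt = torch.optim.Adam(
    itertools.chain(generator.parameters(), discriminator.parameters()),
    lr=2e-4,
)

# opt.zero_grad() now clears EVERY gradient it owns; to zero only the
# generator/encoder subset you would have to clear those .grad fields yourself:
for p in generator.parameters():
    p.grad = None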
with torch.no_grad() should alleviate the problem in all or most cases, as far as I can tell.