
Understanding when to call zero_grad() in pytorch, when training with multiple losses

I am going through an open-source implementation of a domain-adversarial model (GAN-like). The implementation uses PyTorch, and I am not sure it uses zero_grad() correctly. It calls zero_grad() for the encoder optimizer (a.k.a. the generator) before the discriminator update. However, zero_grad() is hardly documented, and I couldn't find information about it.

Here is pseudocode comparing standard GAN training (option 1) with their implementation (option 2). I think the second option is wrong, because the D_loss gradients may accumulate in the parameters that E_opt updates. Can someone tell me whether these two pieces of code are equivalent?

Option 1 (a standard GAN implementation):

X, y = get_D_batch()
D_opt.zero_grad()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()

X, y = get_E_batch()
E_opt.zero_grad()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()

Option 2 (calling zero_grad() for both optimizers at the beginning):


E_opt.zero_grad()
D_opt.zero_grad()

X, y = get_D_batch()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()

X, y = get_E_batch()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
Yuval Atzmon asked Apr 07 '20 10:04

1 Answer

It depends on the params argument passed to the torch.optim.Optimizer subclasses (e.g. torch.optim.SGD) and on the exact structure of the model.
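To illustrate why the params argument matters, here is a minimal sketch (not taken from the implementation in question) showing that an optimizer's zero_grad() only clears gradients of the parameters it was constructed with. The two nn.Linear layers and their sizes are stand-ins chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical two-part model; names and sizes are illustrative assumptions.
encoder = nn.Linear(4, 3)
decoder = nn.Linear(3, 1)

E_opt = torch.optim.SGD(encoder.parameters(), lr=0.1)
D_opt = torch.optim.SGD(decoder.parameters(), lr=0.1)

# One backward pass fills .grad for every parameter that took part in the loss.
loss = decoder(encoder(torch.randn(8, 4))).mean()
loss.backward()

# Clears only the gradients of the parameters E_opt was built with.
E_opt.zero_grad()


def is_cleared(p):
    # Depending on the PyTorch version, zero_grad() either sets .grad to None
    # or fills it with zeros, so both are accepted here.
    return p.grad is None or bool((p.grad == 0).all())


encoder_cleared = all(is_cleared(p) for p in encoder.parameters())
decoder_has_grad = all(
    p.grad is not None and bool((p.grad != 0).any()) for p in decoder.parameters()
)
print(encoder_cleared, decoder_has_grad)
```

The decoder's gradients survive E_opt.zero_grad(), so whether stale gradients matter depends on which optimizer owns which parameters.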

Assuming E_opt and D_opt have different sets of parameters (model.encoder and model.decoder do not share weights), something like this:

E_opt = torch.optim.Adam(model.encoder.parameters())
D_opt = torch.optim.Adam(model.decoder.parameters())

both options MIGHT indeed be equivalent. See the comments added to your code below; I have also added backward(), which is really important here, and renamed model to discriminator and generator where appropriate, as I assume that's the case:

# Starting with zero gradient
E_opt.zero_grad()
D_opt.zero_grad()

# See comment below for possible cases
X, y = get_D_batch()
pred = discriminator(X)
D_loss = loss(pred, y)
# This will accumulate gradients in discriminator only
# OR in discriminator and generator, depends on other parts of code
# See below for commentary
D_loss.backward()
# Correct weights of discriminator
D_opt.step()

# This only relies on random noise input so discriminator
# is not part of this equation
X, y = get_E_batch()
pred = generator(X)
E_loss = loss(pred, y)
E_loss.backward()
# So only parameters of generator are updated always
E_opt.step()

Now it all comes down to how get_D_batch feeds data to the discriminator.

Case 1 - real samples

This is not a problem, as it does not involve the generator: you pass real samples, and only the discriminator takes part in this operation.

Case 2 - generated samples

Naive case

Here gradient accumulation may indeed occur. It would occur if get_D_batch simply called X = generator(noise) and passed this data to the discriminator. In that case both the discriminator and the generator have gradients accumulated during backward(), as both are used.
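The naive case can be reproduced with a small sketch. The networks below are arbitrary stand-ins (single nn.Linear layers, an assumption made only for brevity); the point is that backward() on the discriminator loss also writes into the generator's .grad, and a second pass without zero_grad() adds to it.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in networks; layer sizes are arbitrary assumptions.
generator = nn.Linear(2, 4)
discriminator = nn.Linear(4, 1)
loss_fn = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 2)
fake_labels = torch.zeros(8, 1)


def naive_D_backward():
    # fake stays attached to the generator's graph, so backward() reaches it.
    fake = generator(noise)
    D_loss = loss_fn(discriminator(fake), fake_labels)
    D_loss.backward()


naive_D_backward()
first = generator.weight.grad.clone()

# No zero_grad() in between, so the second pass adds to the first.
naive_D_backward()
second = generator.weight.grad.clone()

generator_got_grad = bool(first.abs().sum() > 0)
accumulated = torch.allclose(second, 2 * first)
print(generator_got_grad, accumulated)
```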

Correct case

We should take the generator out of the equation. The PyTorch DCGAN example has a short line like this:

# Generate fake image batch with G
fake = generator(noise)
output = discriminator(fake.detach()).view(-1)

detach() "stops" the gradient: it returns a tensor that is cut off from the computational graph, so gradients are not backpropagated through it. The discriminator loss therefore leaves the generator's gradients untouched, and no accumulation happens.
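A quick way to check the effect of detach() is to inspect .grad after the backward pass. This is a sketch with stand-in single-layer networks (an assumption for brevity), not the DCGAN example itself.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in networks; layer sizes are arbitrary assumptions.
generator = nn.Linear(2, 4)
discriminator = nn.Linear(4, 1)
loss_fn = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 2)
fake_labels = torch.zeros(8, 1)

fake = generator(noise)
# detach() cuts the link back to the generator for this forward pass.
D_loss = loss_fn(discriminator(fake.detach()), fake_labels)
D_loss.backward()

generator_untouched = all(p.grad is None for p in generator.parameters())
discriminator_has_grad = all(p.grad is not None for p in discriminator.parameters())
# The generator's graph was still built; it is only bypassed by detach().
graph_was_built = fake.grad_fn is not None
print(generator_untouched, discriminator_has_grad, graph_was_built)
```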

Another way (IMO better) would be to use a with torch.no_grad(): block like this:

# Generate fake image batch with G
with torch.no_grad():
    fake = generator(noise)
output = discriminator(fake).view(-1)

This way the generator's operations do not build any part of the graph, so we get better performance (in the first case the graph would be built and only detached afterwards).
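The difference from detach() can be seen on the tensor itself. In this sketch (stand-in single-layer networks again, assumed only for illustration) the output produced under torch.no_grad() carries no graph at all.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in networks; layer sizes are arbitrary assumptions.
generator = nn.Linear(2, 4)
discriminator = nn.Linear(4, 1)
loss_fn = nn.BCEWithLogitsLoss()

noise = torch.randn(8, 2)
fake_labels = torch.zeros(8, 1)

with torch.no_grad():
    fake = generator(noise)

# No graph was recorded for the generator's forward pass.
no_graph = fake.grad_fn is None and not fake.requires_grad

D_loss = loss_fn(discriminator(fake), fake_labels)
D_loss.backward()

generator_untouched = all(p.grad is None for p in generator.parameters())
discriminator_has_grad = all(p.grad is not None for p in discriminator.parameters())
print(no_graph, generator_untouched, discriminator_has_grad)
```

One consequence worth keeping in mind: a tensor created under torch.no_grad() cannot be reused later to backpropagate into the generator, so the generator update needs its own forward pass.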


All in all, the first option is better for standard GANs, as one doesn't have to think about such things (people implementing it should, but readers should not). There are also other approaches, such as a single optimizer for both generator and discriminator (in which case one cannot call zero_grad() on the optimizer for only a subset of parameters, e.g. the encoder), weight sharing, and others, which further clutter the picture.

with torch.no_grad() should alleviate the problem in all/most cases as far as I can tell and imagine ATM.
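Putting the pieces together, one iteration in the style of Option 1 might look like the sketch below. It assumes a standard GAN setup in which the generator loss is computed through the discriminator, which may differ from the domain-adversarial code in the question; the networks, batch shapes, and labels are hypothetical stand-ins.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in networks and data; all sizes are arbitrary assumptions.
generator = nn.Linear(2, 4)
discriminator = nn.Linear(4, 1)
E_opt = torch.optim.Adam(generator.parameters())
D_opt = torch.optim.Adam(discriminator.parameters())
loss_fn = nn.BCEWithLogitsLoss()

real = torch.randn(8, 4)
noise = torch.randn(8, 2)
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

g_before = generator.weight.detach().clone()
d_before = discriminator.weight.detach().clone()

# Discriminator update: zero its gradients right before its own backward().
D_opt.zero_grad()
with torch.no_grad():
    fake = generator(noise)
D_loss = loss_fn(discriminator(real), ones) + loss_fn(discriminator(fake), zeros)
D_loss.backward()
D_opt.step()

# Generator update (assumed here to go through the discriminator).
E_opt.zero_grad()
E_loss = loss_fn(discriminator(generator(noise)), ones)
# This also leaves gradients on the discriminator; they are harmless because
# D_opt.zero_grad() runs at the start of the next discriminator update.
E_loss.backward()
E_opt.step()

generator_changed = not torch.equal(g_before, generator.weight)
discriminator_changed = not torch.equal(d_before, discriminator.weight)
print(generator_changed, discriminator_changed)
```

Calling zero_grad() immediately before each backward() keeps every update independent of whatever the other loss left behind, which is the property that makes Option 1 easy to reason about.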

Szymon Maszke answered Oct 20 '22 06:10
