I'm investigating the use of a Wasserstein GAN with gradient penalty in PyTorch, but consistently get large, positive generator losses that increase over epochs.
I'm borrowing heavily from Caogang's implementation, but I'm using the discriminator and generator losses from this implementation instead, because I get "Invalid gradient at index 0 - expected shape [] but got [1]" if I try to call .backward() with the one and mone arguments used in the Caogang implementation.
I'm training on an augmented WikiArt dataset (>400k 64x64 images) and on CIFAR-10, and I have gotten a normal WGAN (with weight clipping) to work [i.e. it produces passable images after 25 epochs], despite the fact that the D and G losses both hover around 3 [I calculate them using torch.mean(D_real) etc.] for all epochs. However, in the WGAN-GP version the generator loss increases dramatically on both WikiArt and CIFAR-10, and on WikiArt it completely fails to generate anything other than noise.
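As far as I can tell, that error just means the gradient tensor passed to .backward() must match the shape of the 0-dim (scalar) loss; a minimal reproduction of what I believe is happening:

import torch

x = torch.randn(5, requires_grad=True)
d_out = x.mean()                        # 0-dim tensor, shape [] (like the mean critic output)
# d_out.backward(torch.tensor([1.0]))   # shape [1] gradient -> "expected shape [] but got [1]"
d_out.backward(torch.tensor(1.0))       # 0-dim gradient matches, equivalent to d_out.backward()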
Here's an example of the loss after 25 epochs on CIFAR-10:

I don't use any tricks like one-sided label smoothing, and I train with the default learning rate of 0.001, the Adam optimizer, and 5 discriminator updates for every generator update. Why does this crazy loss behaviour happen, and why does the normal weight-clipping WGAN still 'work' on WikiArt while WGAN-GP fails completely?
This happens irrespective of the architecture: it occurs whether both G and D are plain DCGANs, or when using this modified DCGAN, the Creative Adversarial Network (CAN), which requires that D also classify images by style and that G generate stylistically ambiguous images.
Below is the relevant part of my current training method:
self.generator = Can64Generator(self.z_noise, self.channels, self.num_gen_filters).to(self.device)
self.discriminator = WCan64Discriminator(self.channels, self.y_dim, self.num_disc_filters).to(self.device)
style_criterion = nn.CrossEntropyLoss()
self.disc_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))
self.gen_optimizer = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))
while i < len(dataloader):
    j = 0
    disc_loss_epoch = []
    gen_loss_epoch = []
    if self.type == "can":
        disc_class_loss_epoch = []
        gen_class_loss_epoch = []
    if self.gradient_penalty == False:
        # critic training schedule from the official WGAN implementation
        if gen_iterations < 25 or (gen_iterations % 500 == 0):
            disc_iters = 100
        else:
            disc_iters = self.disc_iterations
    while j < disc_iters and (i < len(dataloader)):
        # if using WGAN with weight clipping
        if self.gradient_penalty == False:
            # train the discriminator (critic)
            for param in self.discriminator.parameters():
                param.data.clamp_(self.lower_clamp, self.upper_clamp)
        for param in self.discriminator.parameters():
            param.requires_grad_(True)
        j += 1
        i += 1
        data = data_iterator.next()
        self.discriminator.zero_grad()
        real_images, image_labels = data
        # image labels are the image's classes (e.g. Impressionism)
        real_images = real_images.to(self.device)
        batch_size = real_images.size(0)
        real_image_labels = torch.LongTensor(batch_size).to(self.device)
        real_image_labels.copy_(image_labels)
        labels = torch.full((batch_size,), real_label, device=self.device)

        if self.type == 'can':
            predicted_output_real, predicted_styles_real = self.discriminator(real_images.detach())
            predicted_styles_real = predicted_styles_real.to(self.device)
            disc_class_loss = style_criterion(predicted_styles_real, real_image_labels)
            disc_class_loss.backward(retain_graph=True)
        else:
            predicted_output_real = self.discriminator(real_images.detach())
        disc_loss_real = -torch.mean(predicted_output_real)

        # fake
        noise = torch.randn(batch_size, self.z_noise, 1, 1, device=self.device)
        with torch.no_grad():
            noise_g = noise.detach()
        fake_images = self.generator(noise_g)
        labels.fill_(fake_label)
        if self.type == 'can':
            predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images)
        else:
            predicted_output_fake = self.discriminator(fake_images)
        disc_gen_z_1 = predicted_output_fake.mean().item()
        disc_loss_fake = torch.mean(predicted_output_fake)

        # via https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/WGAN_GP.py
        if self.gradient_penalty:
            # gradient penalty: lambda_ * E[(||grad_x_hat D(x_hat)||_2 - 1)^2]
            alpha = torch.rand((real_images.size()[0], 1, 1, 1)).to(self.device)
            x_hat = alpha * real_images.data + (1 - alpha) * fake_images.data
            x_hat.requires_grad_(True)
            if self.type == 'can':
                pred_hat, _ = self.discriminator(x_hat)
            else:
                pred_hat = self.discriminator(x_hat)
            gradients = grad(outputs=pred_hat, inputs=x_hat,
                             grad_outputs=torch.ones(pred_hat.size()).to(self.device),
                             create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradient_penalty = lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()
            disc_loss = disc_loss_fake + disc_loss_real + gradient_penalty
        else:
            disc_loss = disc_loss_fake + disc_loss_real
        if self.type == 'can':
            disc_loss += disc_class_loss.mean()
        disc_x = disc_loss.mean().item()
        disc_loss.backward(retain_graph=True)
        self.disc_optimizer.step()

    # train the generator
    for param in self.discriminator.parameters():
        param.requires_grad_(False)
    self.generator.zero_grad()
    labels.fill_(real_label)
    if self.type == 'can':
        predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images)
        predicted_styles_fake = predicted_styles_fake.to(self.device)
    else:
        predicted_output_fake = self.discriminator(fake_images)
    gen_loss = -torch.mean(predicted_output_fake)
    disc_gen_z_2 = gen_loss.mean().item()
    if self.type == 'can':
        fake_batch_labels = 1.0 / self.y_dim * torch.ones_like(predicted_styles_fake)
        fake_batch_labels = torch.mean(fake_batch_labels, 1).long().to(self.device)
        gen_class_loss = style_criterion(predicted_styles_fake, fake_batch_labels)
        gen_class_loss.backward(retain_graph=True)
        gen_loss += gen_class_loss.mean()
    gen_loss.backward()
    gen_iterations += 1
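To spell out the sign conventions, the loss terms the loop above is building are (as I understand the WGAN-GP formulation; the helper names are mine, just for illustration):

import torch

def critic_loss(d_real, d_fake, gp_term):
    # critic minimizes mean(D(fake)) - mean(D(real)) plus the (already
    # lambda-weighted) gradient penalty term, i.e. the negative Wasserstein
    # distance estimate plus the penalty
    return torch.mean(d_fake) - torch.mean(d_real) + gp_term

def generator_loss(d_fake):
    # generator tries to raise the critic's score on its samples
    return -torch.mean(d_fake)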
This is the code for the (DCGAN) generator:
class Can64Generator(nn.Module):
    def __init__(self, z_noise, channels, num_gen_filters):
        super(Can64Generator, self).__init__()
        self.ngpu = 1
        self.main = nn.Sequential(
            nn.ConvTranspose2d(z_noise, num_gen_filters * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_gen_filters * 16),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters * 16, num_gen_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_gen_filters * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters * 4, num_gen_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_gen_filters * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters * 2, num_gen_filters, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_gen_filters),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inp):
        output = self.main(inp)
        return output
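In case it is relevant, a quick sanity check of the generator's output shape (placeholder hyperparameter values, and assuming the class above plus import torch / torch.nn as nn):

import torch

gen = Can64Generator(z_noise=100, channels=3, num_gen_filters=64)  # placeholder values
fake = gen(torch.randn(4, 100, 1, 1))
print(fake.shape)  # torch.Size([4, 3, 64, 64])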
And this is the (current) CAN discriminator, which has extra layers for style (image class) classification:
class Can64Discriminator(nn.Module):
    def __init__(self, channels, y_dim, num_disc_filters):
        super(Can64Discriminator, self).__init__()
        self.ngpu = 1
        self.conv = nn.Sequential(
            nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_disc_filters * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # was this
        # self.final_conv = nn.Conv2d(num_disc_filters * 8, num_disc_filters * 8, 4, 2, 1, bias=False)
        self.real_fake_head = nn.Linear(num_disc_filters * 8, 1)
        # no bn and lrelu needed
        self.sig = nn.Sigmoid()
        self.fc = nn.Sequential()
        self.fc.add_module("linear_layer{0}".format(num_disc_filters * 16), nn.Linear(num_disc_filters * 8, num_disc_filters * 16))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters * 8), nn.Linear(num_disc_filters * 16, num_disc_filters * 8))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters), nn.Linear(num_disc_filters * 8, y_dim))
        self.fc.add_module('softmax', nn.Softmax(dim=1))

    def forward(self, inp):
        x = self.conv(inp)
        x = x.view(x.size(0), -1)
        real_out = self.sig(self.real_fake_head(x))
        real_out = real_out.view(-1, 1).squeeze(1)
        style = self.fc(x)
        # style = torch.mean(style, 1) # CrossEntropyLoss requires input be (N, C)
        return real_out, style
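Again, a quick sanity check of the two heads' output shapes (placeholder values for y_dim and num_disc_filters, assuming the class above):

import torch

disc = Can64Discriminator(channels=3, y_dim=27, num_disc_filters=64)  # placeholder values
real_fake_out, style_out = disc(torch.randn(4, 3, 64, 64))
print(real_fake_out.shape)  # torch.Size([4])
print(style_out.shape)      # torch.Size([4, 27])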
The only differences between the WGAN-GP version and the WGAN version of my GAN are that the WGAN version uses RMSprop with lr=0.00005 and clips the weights of the discriminator, as per the WGAN paper.
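Spelled out, the setup difference is just this (a sketch with stand-in modules and placeholder values, since the real objects live on self):

import torch.nn as nn
import torch.optim as optim

# stand-ins so the snippet runs; in my code these are Can64Generator / WCan64Discriminator
generator = nn.Linear(8, 8)
discriminator = nn.Linear(8, 8)
beta1 = 0.5  # placeholder for self.beta1

# WGAN (weight-clipping) variant, as in the WGAN paper:
wgan_disc_opt = optim.RMSprop(discriminator.parameters(), lr=0.00005)
wgan_gen_opt = optim.RMSprop(generator.parameters(), lr=0.00005)
# ...plus, before each critic update:
#     for p in discriminator.parameters():
#         p.data.clamp_(lower_clamp, upper_clamp)

# WGAN-GP variant (the training code above):
wgangp_disc_opt = optim.Adam(discriminator.parameters(), lr=0.001, betas=(beta1, 0.9))
wgangp_gen_opt = optim.Adam(generator.parameters(), lr=0.001, betas=(beta1, 0.9))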
What could be causing this? I'd like to make as small a change as possible, since I want to compare the loss functions alone. The same problem occurs even when using an unmodified DCGAN discriminator on CIFAR-10. Could it simply be that I have only trained for 25 epochs so far, or is there another reason? Interestingly, my GAN also completely fails to generate anything other than noise when using an LSGAN loss (nn.MSELoss()).
Thanks in advance!
Batch normalization in the discriminator breaks Wasserstein GANs with gradient penalty. The authors themselves advocate using layer normalization instead; this is stated explicitly (in bold) in their paper (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf). It is hard to say whether there are other bugs in your code, but I urge you to read the DCGAN and Wasserstein GAN papers thoroughly and take careful notes on the hyperparameters. Getting them wrong really destroys the performance of the GAN, and a hyperparameter search gets expensive quite quickly.
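Concretely, the change I mean is along these lines: a sketch only, reusing your layer sizes; nn.GroupNorm with a single group is one way to get a per-sample, layer-norm-style normalization for conv feature maps without hard-coding the spatial size:

import torch.nn as nn

# Sketch: your critic's conv stack with the nn.BatchNorm2d layers swapped
# for a batch-independent normalization. nn.GroupNorm(1, C) normalizes each
# sample over all of its channels and spatial positions.
def make_critic_conv(channels, num_disc_filters):
    return nn.Sequential(
        nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
        nn.GroupNorm(1, num_disc_filters),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
        nn.GroupNorm(1, num_disc_filters * 2),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
        nn.GroupNorm(1, num_disc_filters * 4),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
        nn.GroupNorm(1, num_disc_filters * 8),
        nn.LeakyReLU(0.2, inplace=True),
    )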
By the way, transposed convolutions produce checkerboard artifacts in your output images. Use image resizing (e.g. nearest-neighbour upsampling followed by a regular convolution) instead. For an in-depth explanation of that phenomenon I can recommend the following resource (https://distill.pub/2016/deconv-checkerboard/).
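For instance, a block like this (a rough sketch, not your exact architecture) can stand in for each stride-2 nn.ConvTranspose2d in the generator:

import torch.nn as nn

# Sketch: resize-then-convolve upsampling as a replacement for a stride-2
# nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1). A stride-1 conv after a
# nearest-neighbour resize avoids the uneven kernel overlap behind the
# checkerboard pattern.
def upsample_block(in_ch, out_ch):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(True),
    )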