Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

After some number of epochs, fake image generation becomes worse in a GAN

I'm trying to create a GAN model. This is my discriminator.py:

import torch.nn as nn
class D(nn.Module):
    """DCGAN-style discriminator made of strided convolutions.

    The network takes 4-channel images, repeatedly halves the spatial size
    while widening the channel count, and finishes with a single-channel
    convolution followed by a sigmoid, so every output value lies in (0, 1).
    """

    # Hyper-parameters shared by the down-sampling convolutions.
    feature_maps = 64
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False
    inplace = True

    def __init__(self):
        super(D, self).__init__()

        def downsample(in_channels, out_channels):
            # Strided convolution used by every stage except the last one.
            return nn.Conv2d(in_channels, out_channels, self.kernel_size, self.stride,
                             self.padding, bias=self.bias)

        def activation():
            return nn.LeakyReLU(0.2, inplace=self.inplace)

        base = self.feature_maps
        widths = (base, base * 2, base * (2 * 2), base * (2 * 2 * 2))

        # The first stage has no batch normalisation.
        layers = [downsample(4, widths[0]), activation()]
        # Remaining stages: convolution -> batch norm -> activation.
        for narrow, wide in zip(widths, widths[1:]):
            layers += [downsample(narrow, wide), nn.BatchNorm2d(wide), activation()]
        # Final un-strided, un-padded convolution down to one channel, then a sigmoid.
        layers += [nn.Conv2d(widths[-1], 1, self.kernel_size, 1, 0, bias=self.bias), nn.Sigmoid()]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the batch through the network and return the scores flattened to 1-D."""
        scores = self.main(input)
        return scores.view(-1)

this is my generator.py

import torch.nn as nn
class G(nn.Module):
    """DCGAN-style generator made of transposed convolutions.

    A latent tensor with ``input_vector`` channels is up-sampled stage by
    stage while the channel count is halved, ending in a 4-channel image
    squashed to (-1, 1) by ``tanh``. For a ``(N, input_vector, 1, 1)`` input
    the output is ``(N, 4, 64, 64)``.
    """

    # Hyper-parameters shared by the up-sampling transposed convolutions.
    feature_maps = 512
    kernel_size = 4
    stride = 2
    padding = 1
    bias = False

    def __init__(self, input_vector):
        super(G, self).__init__()

        def upsample(in_channels, out_channels):
            # Strided transposed convolution used by every stage except the first.
            return nn.ConvTranspose2d(in_channels, out_channels, self.kernel_size, self.stride,
                                      self.padding, bias=self.bias)

        # Channel widths of the successive stages.
        full = self.feature_maps
        half = int(full // 2)
        quarter = int((full // 2) // 2)
        eighth = int(((full // 2) // 2) // 2)

        # Projection of the latent vector: un-strided, un-padded.
        layers = [
            nn.ConvTranspose2d(input_vector, full, self.kernel_size, 1, 0, bias=self.bias),
            nn.BatchNorm2d(full),
            nn.ReLU(True),
        ]
        # Intermediate stages: transposed convolution -> batch norm -> ReLU.
        for wide, narrow in ((full, half), (half, quarter), (quarter, eighth)):
            layers.extend([upsample(wide, narrow), nn.BatchNorm2d(narrow), nn.ReLU(True)])
        # Output stage: 4 image channels squashed by tanh.
        layers.extend([upsample(eighth, 4), nn.Tanh()])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of latent tensors to a batch of generated images."""
        images = self.main(input)
        return images

This is my gans.py

# Importing the libraries
from __future__ import print_function
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from generator import G
from discriminator import D
import os
from PIL import Image

# Training configuration.
batchSize = 64  # Number of images per mini-batch.
imageSize = 64  # Images are resized to imageSize x imageSize (64x64).
input_vector = 100  # Number of channels of the noise tensor fed to the generator.
nb_epochs = 500  # Number of passes over the dataset.

# Pre-processing applied to every loaded image: resize, convert to a tensor,
# then normalise each of the 4 channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((imageSize, imageSize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)),
])


def pil_loader_rgba(path: str) -> Image.Image:
    """Open the image file at *path* and return it converted to RGBA mode.

    The conversion is performed while the file is still open, so the returned
    image no longer depends on the file handle.
    """
    with open(path, 'rb') as handle:
        rgba_image = Image.open(handle).convert('RGBA')
    return rgba_image


# Build the dataset from ./data; every image is opened as RGBA and pre-processed by `transform`.
dataset = dset.ImageFolder(
    root='./data',
    transform=transform,
    loader=pil_loader_rgba,
)

# The DataLoader serves the training images batch by batch, shuffled, using 2 worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=2,
)


# Weight initialiser applied to both networks through ``Module.apply``.
def weights_init(m):
    """Re-initialise module *m* in place according to its class name.

    * Class names containing ``Conv``: weights drawn from N(0.0, 0.02).
    * Class names containing ``BatchNorm``: weights drawn from N(1.0, 0.02),
      biases set to 0.
    * Any other module is left untouched.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def is_cuda_available():
    """Return whatever ``torch.cuda.is_available()`` reports."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready


def is_gpu_available():
    """Return True only when CUDA is available and at least one device is counted."""
    if not is_cuda_available():
        return False
    # The device count is only queried once CUDA itself is reported as available.
    return int(torch.cuda.device_count()) > 0


# Create results directory
def create_dir(name):
    """Create directory *name*, including missing parents, if it does not exist.

    Calling this for a directory that already exists is a no-op. The check
    and the creation happen in a single ``os.makedirs`` call, so another
    process creating the same directory concurrently cannot make this fail.

    Args:
        name: Path of the directory to create.

    Raises:
        FileExistsError: If *name* exists but is not a directory.
    """
    os.makedirs(name, exist_ok=True)


# Creating the generator: its input has `input_vector` channels, and its
# modules are initialised with weights_init.
netG = G(input_vector)
netG.apply(weights_init)

# Creating the discriminator, initialised with the same weights_init function.
netD = D()
netD.apply(weights_init)

# Move both networks to the GPU when one is available. This happens before the
# optimizers below are created from the networks' parameters.
if is_gpu_available():
    netG.cuda()
    netD.cuda()

# Training the DCGANs

# Binary cross-entropy loss, used for both the discriminator and the generator updates.
criterion = nn.BCELoss()
# One Adam optimizer per network, both with lr=0.0002 and betas=(0.5, 0.999).
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# File paths (relative to the working directory) of the two checkpoints
# written by save_model and read by load_checkpoint.
generator_model = 'generator_model'
discriminator_model = 'discriminator_model'


def save_model(epoch, model, optimizer, error, filepath, noise=None):
    """Write a training checkpoint to *filepath*, replacing any previous one.

    The checkpoint is first written to ``<filepath>.tmp`` and then moved over
    the destination with ``os.replace``. An interruption during the write
    therefore leaves the previous checkpoint intact instead of deleting it or
    leaving a truncated file behind.

    Args:
        epoch: Epoch number stored under the ``'epoch'`` key.
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        error: Loss value stored under ``'loss'``.
        filepath: Destination path of the checkpoint file.
        noise: Optional object stored under ``'noise'`` (defaults to ``None``).
    """
    target = os.fspath(filepath)
    staging = target + '.tmp'
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': error,
            'noise': noise
        }, staging)
        # Swap the finished file into place in a single step.
        os.replace(staging, target)
    finally:
        # After a successful replace the staging file is gone; this only
        # cleans up a partial file when saving or replacing failed.
        if os.path.exists(staging):
            os.remove(staging)


def load_checkpoint(filepath, map_location=None):
    """Load the checkpoint stored at *filepath*, or return ``None`` if it is absent.

    Args:
        filepath: Path of a file previously written with ``torch.save``.
        map_location: Optional device remapping forwarded to ``torch.load``.
            When left as ``None`` and CUDA is not available, tensors are
            mapped to the CPU so that a checkpoint saved on a GPU machine can
            still be loaded on a CPU-only machine.

    Returns:
        The object stored in the file, or ``None`` when *filepath* does not exist.
    """
    if not os.path.exists(filepath):
        return None
    if map_location is None and not torch.cuda.is_available():
        # Without this, loading CUDA tensors on a CPU-only machine raises.
        map_location = 'cpu'
    return torch.load(filepath, map_location=map_location)


def _restore_checkpoints():
    """Restore both networks and optimizers from disk; return the epoch to resume at.

    Returns 0 when no checkpoint exists. Otherwise returns the epoch stored in
    the checkpoint, so that epoch is run again (it may have been interrupted
    part-way through).
    """
    start_epoch = 0
    for path, net, optimizer in ((generator_model, netG, optimizerG),
                                 (discriminator_model, netD, optimizerD)):
        checkpoint = load_checkpoint(path)
        if checkpoint:
            net.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = max(start_epoch, checkpoint.get('epoch', 0))
    return start_epoch


def main():
    """Train the GAN, alternating discriminator and generator updates.

    Existing checkpoints are restored once before training starts, and the
    epoch counter continues from the stored epoch. For every mini-batch the
    discriminator is updated on a real and a generated batch, then the
    generator is updated to make the discriminator label its output as real.
    Losses are printed and both checkpoints are written after every
    mini-batch; every 100 mini-batches sample images are written to
    ``./results``.
    """
    use_gpu = is_gpu_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    # torch.cuda.get_device_name() fails without a GPU, so only query it when one is present.
    print("Device name : " + (torch.cuda.get_device_name(0) if use_gpu else "cpu"))

    # Load the checkpoints once. The in-memory state is always at least as
    # recent as what was last written, so re-reading it for every mini-batch
    # only costs disk I/O.
    start_epoch = _restore_checkpoints()

    for epoch in range(start_epoch, nb_epochs):

        for i, data in enumerate(dataloader, 0):
            real, _ = data
            real_batch = real.to(device)
            batch_size = real_batch.size(0)
            real_target = torch.ones(batch_size, device=device)
            fake_target = torch.zeros(batch_size, device=device)

            # 1st Step: Updating the weights of the neural network of the discriminator
            netD.zero_grad()

            # Training the discriminator with a real image of the dataset
            errD_real = criterion(netD(real_batch), real_target)

            # Training the discriminator with a fake image generated by the generator.
            # detach() keeps this backward pass from reaching the generator.
            noise = torch.randn(batch_size, input_vector, 1, 1, device=device)
            fake = netG(noise)
            errD_fake = criterion(netD(fake.detach()), fake_target)

            # Backpropagating the total error
            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            # 2nd Step: Updating the weights of the neural network of the generator.
            # The generator is rewarded when the discriminator labels its images as real.
            netG.zero_grad()
            errG = criterion(netD(fake), real_target)
            errG.backward()
            optimizerG.step()

            # 3rd Step: Printing the losses and saving the checkpoints after every
            # mini-batch, and the real and generated images every 100 mini-batches
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (
                epoch, nb_epochs, i, len(dataloader), errD.item(), errG.item()))
            # NOTE(review): both checkpoints are rewritten for every mini-batch, which is
            # expensive; consider saving once per epoch if that durability is not needed.
            # The losses are detached so the checkpoint does not reference the autograd graph.
            save_model(epoch, netG, optimizerG, errG.detach(), generator_model, noise)
            save_model(epoch, netD, optimizerD, errD.detach(), discriminator_model, noise)

            if i % 100 == 0:
                create_dir('results')
                vutils.save_image(real, '%s/real_samples.png' % "./results", normalize=True)
                # The preview images are only written to disk, so no gradients are needed.
                with torch.no_grad():
                    fake = netG(noise)
                vutils.save_image(fake, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize=True)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

So after a few hours I decided to look at my results folder. I saw a weird thing after the 39th epoch: the generator started generating worse images. Until the 39th epoch the generator had been improving. Please look at the screenshot below. enter image description here

Why did the generator suddenly become worse? I'm trying to run 500 epochs. I thought more epochs meant more success.

So I had a look at the logs, and this is what I see:

[40/500][0/157] Loss_D: 0.0141 Loss_G: 5.7559
[40/500][1/157] Loss_D: 0.0438 Loss_G: 5.5805
[40/500][2/157] Loss_D: 0.0161 Loss_G: 6.4947
[40/500][3/157] Loss_D: 0.0138 Loss_G: 7.1711
[40/500][4/157] Loss_D: 0.0547 Loss_G: 4.6262
[40/500][5/157] Loss_D: 0.0295 Loss_G: 4.7831
[40/500][6/157] Loss_D: 0.0103 Loss_G: 6.3700
[40/500][7/157] Loss_D: 0.0276 Loss_G: 5.9162
[40/500][8/157] Loss_D: 0.0205 Loss_G: 6.3571
[40/500][9/157] Loss_D: 0.0139 Loss_G: 6.4961
[40/500][10/157] Loss_D: 0.0117 Loss_G: 6.4371
[40/500][11/157] Loss_D: 0.0057 Loss_G: 6.6858
[40/500][12/157] Loss_D: 0.0203 Loss_G: 5.4308
[40/500][13/157] Loss_D: 0.0078 Loss_G: 6.5749
[40/500][14/157] Loss_D: 0.0115 Loss_G: 6.3202
[40/500][15/157] Loss_D: 0.0187 Loss_G: 6.2258
[40/500][16/157] Loss_D: 0.0052 Loss_G: 6.5253
[40/500][17/157] Loss_D: 0.0158 Loss_G: 5.5672
[40/500][18/157] Loss_D: 0.0156 Loss_G: 5.5416
[40/500][19/157] Loss_D: 0.0306 Loss_G: 5.4550
[40/500][20/157] Loss_D: 0.0077 Loss_G: 6.1985
[40/500][21/157] Loss_D: 0.0158 Loss_G: 5.3092
[40/500][22/157] Loss_D: 0.0167 Loss_G: 5.8395
[40/500][23/157] Loss_D: 0.0119 Loss_G: 6.0849
[40/500][24/157] Loss_D: 0.0104 Loss_G: 6.5493
[40/500][25/157] Loss_D: 0.0182 Loss_G: 5.6758
[40/500][26/157] Loss_D: 0.0145 Loss_G: 5.8336
[40/500][27/157] Loss_D: 0.0050 Loss_G: 6.8472
[40/500][28/157] Loss_D: 0.0080 Loss_G: 6.4894
[40/500][29/157] Loss_D: 0.0186 Loss_G: 5.5563
[40/500][30/157] Loss_D: 0.0143 Loss_G: 6.4144
[40/500][31/157] Loss_D: 0.0377 Loss_G: 5.4557
[40/500][32/157] Loss_D: 0.0540 Loss_G: 4.6034
[40/500][33/157] Loss_D: 0.0200 Loss_G: 5.6417
[40/500][34/157] Loss_D: 0.0189 Loss_G: 5.7760
[40/500][35/157] Loss_D: 0.0197 Loss_G: 6.1732
[40/500][36/157] Loss_D: 0.0093 Loss_G: 6.4046
[40/500][37/157] Loss_D: 0.0281 Loss_G: 5.5217
[40/500][38/157] Loss_D: 0.0410 Loss_G: 5.9157
[40/500][39/157] Loss_D: 0.0667 Loss_G: 5.2522
[40/500][40/157] Loss_D: 0.0530 Loss_G: 5.6412
[40/500][41/157] Loss_D: 0.0315 Loss_G: 5.9325
[40/500][42/157] Loss_D: 0.0097 Loss_G: 6.7819
[40/500][43/157] Loss_D: 0.0157 Loss_G: 5.8630
[40/500][44/157] Loss_D: 0.0382 Loss_G: 5.1942
[40/500][45/157] Loss_D: 0.0331 Loss_G: 5.1490
[40/500][46/157] Loss_D: 0.0362 Loss_G: 5.7026
[40/500][47/157] Loss_D: 0.0237 Loss_G: 5.7493
[40/500][48/157] Loss_D: 0.0227 Loss_G: 5.7636
[40/500][49/157] Loss_D: 0.0230 Loss_G: 5.6500
[40/500][50/157] Loss_D: 0.0329 Loss_G: 5.4542
[40/500][51/157] Loss_D: 0.0306 Loss_G: 5.6473
[40/500][52/157] Loss_D: 0.0254 Loss_G: 5.8464
[40/500][53/157] Loss_D: 0.0402 Loss_G: 5.8609
[40/500][54/157] Loss_D: 0.0242 Loss_G: 5.9952
[40/500][55/157] Loss_D: 0.0400 Loss_G: 5.8378
[40/500][56/157] Loss_D: 0.0302 Loss_G: 5.8990
[40/500][57/157] Loss_D: 0.0239 Loss_G: 5.8134
[40/500][58/157] Loss_D: 0.0348 Loss_G: 5.8109
[40/500][59/157] Loss_D: 0.0361 Loss_G: 5.9011
[40/500][60/157] Loss_D: 0.0418 Loss_G: 5.8825
[40/500][61/157] Loss_D: 0.0501 Loss_G: 6.2302
[40/500][62/157] Loss_D: 0.0184 Loss_G: 6.2755
[40/500][63/157] Loss_D: 0.0273 Loss_G: 5.9655
[40/500][64/157] Loss_D: 0.0250 Loss_G: 5.7513
[40/500][65/157] Loss_D: 0.0298 Loss_G: 6.0434
[40/500][66/157] Loss_D: 0.0299 Loss_G: 6.4280
[40/500][67/157] Loss_D: 0.0205 Loss_G: 6.3743
[40/500][68/157] Loss_D: 0.0173 Loss_G: 6.2749
[40/500][69/157] Loss_D: 0.0199 Loss_G: 6.0541
[40/500][70/157] Loss_D: 0.0309 Loss_G: 6.5044
[40/500][71/157] Loss_D: 0.0177 Loss_G: 6.6093
[40/500][72/157] Loss_D: 0.0363 Loss_G: 7.2993
[40/500][73/157] Loss_D: 0.0093 Loss_G: 7.6995
[40/500][74/157] Loss_D: 0.0087 Loss_G: 7.3493
[40/500][75/157] Loss_D: 0.0540 Loss_G: 8.2688
[40/500][76/157] Loss_D: 0.0172 Loss_G: 8.3312
[40/500][77/157] Loss_D: 0.0086 Loss_G: 7.6863
[40/500][78/157] Loss_D: 0.0232 Loss_G: 7.4930
[40/500][79/157] Loss_D: 0.0175 Loss_G: 7.8834
[40/500][80/157] Loss_D: 0.0109 Loss_G: 9.5329
[40/500][81/157] Loss_D: 0.0093 Loss_G: 7.3253
[40/500][82/157] Loss_D: 0.0674 Loss_G: 10.6709
[40/500][83/157] Loss_D: 0.0010 Loss_G: 10.8321
[40/500][84/157] Loss_D: 0.0083 Loss_G: 8.5728
[40/500][85/157] Loss_D: 0.0124 Loss_G: 6.9085
[40/500][86/157] Loss_D: 0.0181 Loss_G: 7.0867
[40/500][87/157] Loss_D: 0.0130 Loss_G: 7.3527
[40/500][88/157] Loss_D: 0.0189 Loss_G: 7.2494
[40/500][89/157] Loss_D: 0.0302 Loss_G: 8.7555
[40/500][90/157] Loss_D: 0.0147 Loss_G: 7.7668
[40/500][91/157] Loss_D: 0.0325 Loss_G: 7.7779
[40/500][92/157] Loss_D: 0.0257 Loss_G: 8.3955
[40/500][93/157] Loss_D: 0.0113 Loss_G: 8.3687
[40/500][94/157] Loss_D: 0.0124 Loss_G: 7.6081
[40/500][95/157] Loss_D: 0.0088 Loss_G: 7.6012
[40/500][96/157] Loss_D: 0.0241 Loss_G: 7.6573
[40/500][97/157] Loss_D: 0.0522 Loss_G: 10.8114
[40/500][98/157] Loss_D: 0.0071 Loss_G: 11.0529
[40/500][99/157] Loss_D: 0.0043 Loss_G: 8.0707
[40/500][100/157] Loss_D: 0.0141 Loss_G: 7.2864
[40/500][101/157] Loss_D: 0.0234 Loss_G: 7.3585
[40/500][102/157] Loss_D: 0.0148 Loss_G: 7.4577
[40/500][103/157] Loss_D: 0.0190 Loss_G: 8.1904
[40/500][104/157] Loss_D: 0.0201 Loss_G: 8.1518
[40/500][105/157] Loss_D: 0.0220 Loss_G: 9.1069
[40/500][106/157] Loss_D: 0.0108 Loss_G: 9.0069
[40/500][107/157] Loss_D: 0.0044 Loss_G: 8.0970
[40/500][108/157] Loss_D: 0.0076 Loss_G: 7.2699
[40/500][109/157] Loss_D: 0.0052 Loss_G: 7.4036
[40/500][110/157] Loss_D: 0.0167 Loss_G: 7.2742
[40/500][111/157] Loss_D: 0.0032 Loss_G: 7.9825
[40/500][112/157] Loss_D: 0.3462 Loss_G: 32.6314
[40/500][113/157] Loss_D: 0.1704 Loss_G: 40.6010
[40/500][114/157] Loss_D: 0.0065 Loss_G: 44.4607
[40/500][115/157] Loss_D: 0.0142 Loss_G: 43.9761
[40/500][116/157] Loss_D: 0.0160 Loss_G: 45.0376
[40/500][117/157] Loss_D: 0.0042 Loss_G: 45.9534
[40/500][118/157] Loss_D: 0.0061 Loss_G: 45.2998
[40/500][119/157] Loss_D: 0.0023 Loss_G: 45.4654
[40/500][120/157] Loss_D: 0.0033 Loss_G: 44.6643
[40/500][121/157] Loss_D: 0.0042 Loss_G: 44.6020
[40/500][122/157] Loss_D: 0.0002 Loss_G: 44.4807
[40/500][123/157] Loss_D: 0.0004 Loss_G: 44.0402
[40/500][124/157] Loss_D: 0.0055 Loss_G: 43.9188
[40/500][125/157] Loss_D: 0.0021 Loss_G: 43.1988
[40/500][126/157] Loss_D: 0.0008 Loss_G: 41.6770
[40/500][127/157] Loss_D: 0.0001 Loss_G: 40.8719
[40/500][128/157] Loss_D: 0.0009 Loss_G: 40.3803
[40/500][129/157] Loss_D: 0.0023 Loss_G: 39.0143
[40/500][130/157] Loss_D: 0.0254 Loss_G: 39.0317
[40/500][131/157] Loss_D: 0.0008 Loss_G: 37.9451
[40/500][132/157] Loss_D: 0.0253 Loss_G: 37.1046
[40/500][133/157] Loss_D: 0.0046 Loss_G: 36.2807
[40/500][134/157] Loss_D: 0.0025 Loss_G: 35.5878
[40/500][135/157] Loss_D: 0.0011 Loss_G: 33.6500
[40/500][136/157] Loss_D: 0.0061 Loss_G: 33.5011
[40/500][137/157] Loss_D: 0.0015 Loss_G: 30.0363
[40/500][138/157] Loss_D: 0.0019 Loss_G: 31.0197
[40/500][139/157] Loss_D: 0.0027 Loss_G: 28.4693
[40/500][140/157] Loss_D: 0.0189 Loss_G: 27.3072
[40/500][141/157] Loss_D: 0.0051 Loss_G: 26.6637
[40/500][142/157] Loss_D: 0.0077 Loss_G: 24.8390
[40/500][143/157] Loss_D: 0.0123 Loss_G: 23.8334
[40/500][144/157] Loss_D: 0.0014 Loss_G: 23.3755
[40/500][145/157] Loss_D: 0.0036 Loss_G: 19.6341
[40/500][146/157] Loss_D: 0.0025 Loss_G: 18.1076
[40/500][147/157] Loss_D: 0.0029 Loss_G: 16.9415
[40/500][148/157] Loss_D: 0.0028 Loss_G: 16.4647
[40/500][149/157] Loss_D: 0.0048 Loss_G: 14.6184
[40/500][150/157] Loss_D: 0.0074 Loss_G: 13.2544
[40/500][151/157] Loss_D: 0.0053 Loss_G: 13.0052
[40/500][152/157] Loss_D: 0.0070 Loss_G: 11.8815
[40/500][153/157] Loss_D: 0.0078 Loss_G: 12.1657
[40/500][154/157] Loss_D: 0.0094 Loss_G: 10.4259
[40/500][155/157] Loss_D: 0.0073 Loss_G: 9.9345
[40/500][156/157] Loss_D: 0.0082 Loss_G: 9.7609
[41/500][0/157] Loss_D: 0.0079 Loss_G: 9.2920
[41/500][1/157] Loss_D: 0.0134 Loss_G: 8.5241
[41/500][2/157] Loss_D: 0.0156 Loss_G: 8.6983
[41/500][3/157] Loss_D: 0.0250 Loss_G: 8.1148
[41/500][4/157] Loss_D: 0.0160 Loss_G: 8.3324
[41/500][5/157] Loss_D: 0.0187 Loss_G: 7.6281
[41/500][6/157] Loss_D: 0.0191 Loss_G: 7.4707
[41/500][7/157] Loss_D: 0.0092 Loss_G: 8.3976
[41/500][8/157] Loss_D: 0.0118 Loss_G: 7.9800
[41/500][9/157] Loss_D: 0.0126 Loss_G: 7.3999
[41/500][10/157] Loss_D: 0.0165 Loss_G: 7.0854
[41/500][11/157] Loss_D: 0.0095 Loss_G: 7.6392
[41/500][12/157] Loss_D: 0.0079 Loss_G: 7.3862
[41/500][13/157] Loss_D: 0.0181 Loss_G: 7.3812
[41/500][14/157] Loss_D: 0.0168 Loss_G: 6.9518
[41/500][15/157] Loss_D: 0.0094 Loss_G: 7.8525
[41/500][16/157] Loss_D: 0.0165 Loss_G: 7.3024
[41/500][17/157] Loss_D: 0.0029 Loss_G: 8.4487
[41/500][18/157] Loss_D: 0.0169 Loss_G: 7.0449
[41/500][19/157] Loss_D: 0.0167 Loss_G: 7.1307
[41/500][20/157] Loss_D: 0.0255 Loss_G: 6.7970
[41/500][21/157] Loss_D: 0.0154 Loss_G: 6.9745
[41/500][22/157] Loss_D: 0.0110 Loss_G: 6.9925

As you can see, there is a HUGE change in the generator loss (Loss_G).

Any idea why that happened ?

Any idea how to overcome such a problem ?

like image 461
AMendis Avatar asked Aug 24 '21 08:08

AMendis


People also ask

How many epochs does it take to train a GAN?

You have two options: Use 128 features instead of 196 in both the generator and the discriminator. This should drop training time to around 43 hours for 400 epochs.

Which of the following problems occurs while training a GAN?

However, there exist major challenges in training of GANs, i.e., mode collapse, non-convergence and instability, due to inappropriate design of network architecture, use of objective function and selection of optimization algorithm.

What causes mode collapse in GAN?

Mode collapse happens when the generator can only produce a single type of output or a small set of outputs. This may happen due to problems in training, such as the generator finds a type of data that is easily able to fool the discriminator and thus keeps generating that one type.

What is convergence failure in GAN?

Convergence failure happens when the generator and discriminator do not reach a balance during training.


1 Answers

GAN training is inherently unstable because two competing models are trained simultaneously and dynamically. I plotted the loss values from your question, and the losses of the discriminator and the generator look like this:

enter image description here

Looking at the loss and the generated images, we can say that the training fails to converge. This failure is due to not finding an equilibrium between the discriminator and the generator. We see that the loss of the discriminator is close to zero while the loss of the generator rises and is unstable, resulting in garbage images that the discriminator can easily identify as fake.

The discriminator classifies both the real data and the fake data from the generator. The discriminator loss penalizes the discriminator for misclassifying a real instance as fake, or a fake instance (created by the generator) as real.

The generator loss is based on the discriminator’s classification – the generator gets rewarded if it successfully fools the discriminator, and gets penalized otherwise. Since a GAN is a zero-sum non-cooperative game, the win is either the discriminator's or the generator's: if one wins, the other loses. Convergence happens at a Nash equilibrium, which is when one player's action no longer affects the other's. Read more about it here: https://jonathan-hui.medium.com/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b and https://jonathan-hui.medium.com/gan-what-is-wrong-with-the-gan-cost-function-6f594162ce01, which provide a deeper insight into the GAN challenges.

The convergence failure could also happen due to mode collapse and diminishing gradients. In addition to the exploding-gradients solution suggested by Nihal:

  1. Try implementing early stopping in the model based on metrics such as the Inception Score, Modified Inception Score, Frechet Inception Distance, or Wasserstein distance (taken from this paper: https://arxiv.org/pdf/1802.03446.pdf). These measures help identify model convergence, so training can be stopped automatically once the model converges.

  2. It is also shown that Spectral Normalization, a particular kind of normalization applied on the convolutional kernels, can greatly help the stability of the training. https://arxiv.org/pdf/1802.05957.pdf

  3. Making the training of the discriminator more difficult could help. Adding noise to both real images and images from generator helps increase the complexity of the discriminator training.

Increasing the number of iterations doesn't always improve the model. More training iterations beyond some point of training stability may or may not result in higher-quality images, due to the high-variance loss. And since GANs are relatively new, the research directions on the challenges faced are still open and debatable.

like image 168
Sonia Samipillai Avatar answered Nov 11 '22 08:11

Sonia Samipillai