
Generating new images with PyTorch

I am studying GANs. I've completed a course that gave me an example of a program that generates images based on the input examples.

The example can be found here:

https://github.com/davidsonmizael/gan

So I decided to use it to generate new images based on a dataset of frontal photos of faces, but I am not having any success. Unlike the example above, my code only generates noise, even though the input contains actual images.

Honestly, I have no clue what I should change to point the code in the right direction so that it learns from the images. I haven't changed a single value in the code provided in the example, yet it does not work.

If anyone could help me understand this and point me in the right direction, it would be very helpful. Thanks in advance.

My Discriminator:

import torch.nn as nn

class D(nn.Module):

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
                # input: 3 x 64 x 64 image -> 64 x 32 x 32
                nn.Conv2d(3, 64, 4, 2, 1, bias = False),
                nn.LeakyReLU(0.2, inplace = True),
                # -> 128 x 16 x 16
                nn.Conv2d(64, 128, 4, 2, 1, bias = False),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2, inplace = True),
                # -> 256 x 8 x 8
                nn.Conv2d(128, 256, 4, 2, 1, bias = False),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.2, inplace = True),
                # -> 512 x 4 x 4
                nn.Conv2d(256, 512, 4, 2, 1, bias = False),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(0.2, inplace = True),
                # -> 1 x 1 x 1, squashed to a probability that the image is real
                nn.Conv2d(512, 1, 4, 1, 0, bias = False),
                nn.Sigmoid()
                )

    def forward(self, input):
        # flatten (N, 1, 1, 1) to (N,) so it lines up with the N target labels
        return self.main(input).view(-1)
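This discriminator only works for one input size: each stride-2 convolution halves the height and width, and the last 4x4 convolution without padding has to land exactly on 1x1 so that `.view(-1)` yields one value per image. A small pure-Python sketch of that arithmetic, using the `nn.Conv2d` output-size formula with dilation 1 (the helper names are hypothetical, not part of PyTorch), shows that 64x64 images work while, for example, 128x128 images would leave a 5x5 map per image, which no longer matches the N targets passed to `BCELoss`:

```python
def conv2d_out(size, kernel, stride, padding):
    """Output height/width of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride, padding) of each Conv2d in D, in order
D_CONVS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]


def discriminator_sizes(image_size):
    """Spatial size after each Conv2d of D, starting from the input size."""
    sizes = [image_size]
    for kernel, stride, padding in D_CONVS:
        sizes.append(conv2d_out(sizes[-1], kernel, stride, padding))
    return sizes


print(discriminator_sizes(64))   # [64, 32, 16, 8, 4, 1]
print(discriminator_sizes(128))  # [128, 64, 32, 16, 8, 5]
```

In other words, this `D` assumes the dataloader delivers images resized to 64x64.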

My Generator:

class G(nn.Module):

    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
                # input: 100 x 1 x 1 noise vector -> 512 x 4 x 4
                nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
                nn.BatchNorm2d(512),
                nn.ReLU(True),
                # -> 256 x 8 x 8
                nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
                nn.BatchNorm2d(256),
                nn.ReLU(True),
                # -> 128 x 16 x 16
                nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
                # -> 64 x 32 x 32
                nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                # -> 3 x 64 x 64 image with pixel values in [-1, 1]
                nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
                nn.Tanh()
                )

    def forward(self, input):
        return self.main(input)
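Going the other way, each `nn.ConvTranspose2d` in the generator grows the spatial size. With dilation 1 its output size is `(size - 1) * stride - 2 * padding + kernel_size + output_padding`. A short sketch (hypothetical helper names again, no PyTorch needed) traces the shapes and confirms that `G` turns a 100 x 1 x 1 noise tensor into a 3 x 64 x 64 image, the size `D` expects:

```python
def conv_transpose2d_out(size, kernel, stride, padding, output_padding=0):
    """Output height/width of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# (in_channels, out_channels, kernel_size, stride, padding) of each ConvTranspose2d in G
G_DECONVS = [
    (100, 512, 4, 1, 0),
    (512, 256, 4, 2, 1),
    (256, 128, 4, 2, 1),
    (128, 64, 4, 2, 1),
    (64, 3, 4, 2, 1),
]


def generator_shapes(noise_size=1):
    """(channels, height, width) of the noise input and after each ConvTranspose2d of G."""
    channels, size = G_DECONVS[0][0], noise_size
    shapes = [(channels, size, size)]
    for in_ch, out_ch, kernel, stride, padding in G_DECONVS:
        # consecutive layers must agree on the channel count
        assert in_ch == channels
        channels = out_ch
        size = conv_transpose2d_out(size, kernel, stride, padding)
        shapes.append((channels, size, size))
    return shapes


print(generator_shapes())
# [(100, 1, 1), (512, 4, 4), (256, 8, 8), (128, 16, 16), (64, 32, 32), (3, 64, 64)]
```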

My function to initialize the weights:

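def weights_init(m):
    # DCGAN-style initialization, called on every submodule by net.apply(weights_init)
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Conv2d and ConvTranspose2d: weights ~ N(0, 0.02)
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm2d: scale ~ N(1, 0.02), shift = 0
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)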
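`weights_init` chooses a rule by substring matching on the class name, and `Module.apply` calls it on every submodule as well as on the module itself. The sketch below is a hypothetical helper that reuses the same `find` tests without needing PyTorch; it shows which branch each class name appearing in `G` and `D` would take. Note that `ConvTranspose2d` also contains `'Conv'`, so the generator's layers get the same N(0, 0.02) weights as the discriminator's, while `Sequential`, the activations and the top-level `G`/`D` modules are skipped:

```python
def init_rule(classname):
    """Which branch of weights_init a module with this class name would take."""
    if classname.find('Conv') != -1:
        return "weight ~ N(0.0, 0.02)"
    elif classname.find('BatchNorm') != -1:
        return "weight ~ N(1.0, 0.02), bias = 0"
    return "not touched"


for name in ["Conv2d", "ConvTranspose2d", "BatchNorm2d", "LeakyReLU", "ReLU",
             "Tanh", "Sigmoid", "Sequential", "G", "D"]:
    print(f"{name:16} -> {init_rule(name)}")
```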

Full code can be seen here:

https://github.com/davidsonmizael/criminal-gan

Noise generated at epoch 25: [image]

Input with real images: [image]

asked Nov 13 '17 23:11 by davis


2 Answers

The code from your example (https://github.com/davidsonmizael/gan) gave me the same noise you show. The generator's loss decreased far too quickly.

A few things were buggy; I no longer remember exactly which, but it should be easy to spot the differences yourself. For comparison, also have a look at this tutorial: GANs in 50 lines of PyTorch.

# .... same as your code
print("# Starting generator and discriminator...")
netG = G()
netG.apply(weights_init)

netD = D()
netD.apply(weights_init)

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()

#training the DCGANs
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

epochs = 25

timeElapsed = []
for epoch in range(epochs):
    print("# Starting epoch [%d/%d]..." % (epoch, epochs))
    for i, data in enumerate(dataloader, 0):
        start = time.time()

        #updates the weights of the discriminator nn
        netD.zero_grad()

        #trains the discriminator with a real image
        real, _ = data

        if torch.cuda.is_available():
            inputs = Variable(real.cuda()).cuda()
            target = Variable(torch.ones(inputs.size()[0]).cuda()).cuda()
        else:
            inputs = Variable(real)
            target = Variable(torch.ones(inputs.size()[0]))

        output = netD(inputs)
        errD_real = criterion(output, target)
        errD_real.backward() #retain_graph=True

        #trains the discriminator with a fake image
        if torch.cuda.is_available():
            D_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1).cuda()).cuda()
            target = Variable(torch.zeros(inputs.size()[0]).cuda()).cuda()
        else:
            D_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1))
            target = Variable(torch.zeros(inputs.size()[0]))
        D_fake = netG(D_noise).detach()
        D_fake_output = netD(D_fake)
        errD_fake = criterion(D_fake_output, target)
        errD_fake.backward()

        # NOT backpropagating the total error errD = errD_real + errD_fake:
        # the two backward() calls above already accumulate the same gradients

        optimizerD.step()

    #for i, data in enumerate(dataloader, 0):
        # ^ no second pass over the dataloader: G is updated in the same iteration as D

        #updates the weights of the generator nn
        netG.zero_grad()

        if torch.cuda.is_available():
            G_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1).cuda()).cuda()
            target = Variable(torch.ones(inputs.size()[0]).cuda()).cuda()
        else:
            G_noise = Variable(torch.randn(inputs.size()[0], 100, 1, 1))
            target = Variable(torch.ones(inputs.size()[0]))

        fake = netG(G_noise)
        G_output = netD(fake)
        errG  = criterion(G_output, target)

        #backpropagating the error
        errG.backward()
        optimizerG.step()


        if i % 50 == 0:
            #prints the losses and save the real images and the generated images
            print("# Progress: ")
            print("[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f" % (epoch, epochs, i, len(dataloader), (errD_real + errD_fake).item(), errG.item()))

            #estimates the remaining time by taking the avg seconds per loop
            #and multiplying by the number of loops that still need to run
            timeElapsed.append(time.time() - start)
            avg_time = sum(timeElapsed) / float(len(timeElapsed))
            rem_dtl = (len(dataloader) - i - 1) + ((epochs - epoch - 1) * len(dataloader))
            remaining = rem_dtl * avg_time
            print("# Estimated remaining time: %s" % (time.strftime("%H:%M:%S", time.gmtime(remaining))))

        if i % 100 == 0:
            vutils.save_image(real, "%s/real_samples.png" % "./results", normalize = True)
            vutils.save_image(fake.data, "%s/fake_samples_epoch_%03d.png" % ("./results", epoch), normalize = True)

print ("# Finished.")

Result after 25 epochs (batch size 256) on CIFAR-10: [image]
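The remaining-time estimate printed by the loop comes down to counting the iterations still to run and multiplying by the average seconds per iteration. A standalone sketch of that arithmetic (the function names are hypothetical), with `epoch` and `i` 0-based as in the loop, might look like this:

```python
import math
import time


def remaining_iterations(epoch, i, epochs, batches_per_epoch):
    """Iterations left after finishing batch i of epoch (both 0-based)."""
    left_in_this_epoch = batches_per_epoch - i - 1
    left_in_later_epochs = (epochs - epoch - 1) * batches_per_epoch
    return left_in_this_epoch + left_in_later_epochs


def format_remaining(epoch, i, epochs, batches_per_epoch, seconds_per_iteration):
    """Remaining time as HH:MM:SS (note: %H wraps around after 24 hours)."""
    remaining = remaining_iterations(epoch, i, epochs, batches_per_epoch) * seconds_per_iteration
    return time.strftime("%H:%M:%S", time.gmtime(remaining))


# e.g. batch size 256 over the 50000-image CIFAR-10 training split
batches = math.ceil(50000 / 256)
print(batches)                                   # 196
print(remaining_iterations(0, 0, 25, batches))   # 4899
print(format_remaining(0, 0, 25, batches, 0.5)) # 00:40:49
```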

answered Nov 03 '22 13:11 by Forcetti


GAN training is not very fast. I'm assuming you are not using a pre-trained model but are learning from scratch. At epoch 25 it is quite normal not to see any meaningful patterns in the samples. I realize that the GitHub project shows you something cool after 25 epochs, but that also depends on the size of the dataset. CIFAR-10 (the dataset used on the GitHub page) has 60000 images; 25 epochs means the net has seen all of them 25 times.

I do not know which dataset you are using, but if it is smaller, it might take more epochs until you see results, because the net sees fewer images in total. If the images in your dataset have a higher resolution, it might also take longer.
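One way to make the dataset-size argument concrete is to count optimizer steps: every batch gives `D` and `G` one update each, so the updates per epoch equal the number of batches. The sketch below assumes a batch size of 64 and a hypothetical face dataset of 2000 images (substitute your own numbers), and compares it with the 50000-image CIFAR-10 training split:

```python
import math


def updates_per_epoch(num_images, batch_size):
    """Batches per epoch, counting the last partial batch (DataLoader's default drop_last=False)."""
    return math.ceil(num_images / batch_size)


def total_updates(num_images, batch_size, epochs):
    """Optimizer steps taken in `epochs` passes over the dataset."""
    return updates_per_epoch(num_images, batch_size) * epochs


def epochs_for_updates(target_updates, num_images, batch_size):
    """Epochs needed on a dataset to reach a given number of optimizer steps."""
    return math.ceil(target_updates / updates_per_epoch(num_images, batch_size))


BATCH_SIZE = 64  # assumption, use whatever your dataloader uses
cifar_steps = total_updates(50000, BATCH_SIZE, 25)
print(cifar_steps)                                        # 19550
print(total_updates(2000, BATCH_SIZE, 25))                # 800 (hypothetical 2000-image dataset)
print(epochs_for_updates(cifar_steps, 2000, BATCH_SIZE))  # 611
```

Under these assumptions, matching the number of updates that 25 CIFAR-10 epochs provide would take several hundred epochs on the smaller dataset. The update count is only a rough proxy, since resolution and image content matter too.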

You should check again after at least a few hundred, if not a few thousand, epochs.


For example, on the frontal face photo dataset after 25 epochs: [image]

And after 50 epochs: [image]

answered Nov 03 '22 14:11 by RunOrVeith