Variational Autoencoder gives same output image for every input MNIST image when using KL divergence

When the KL divergence term is not used, the VAE reconstructs MNIST images almost perfectly but fails to generate new ones properly when decoding random noise.

When the KL divergence term is used, the VAE produces the same odd output image for every input, both when reconstructing and when generating.


Here's the PyTorch code for the loss function:

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=True)  # averaged over elements
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # summed over the batch
    return BCE + KLD

Here recon_x is the reconstructed image, x is the original image, mu is the mean vector, and logvar is the vector containing the log of the variance.

What is going wrong here? Thanks in advance :)

Cracin asked May 30 '18 14:05

2 Answers

A possible reason is the numerical imbalance between the two losses: your BCE loss is computed as an average over the batch (cf. size_average=True), while the KLD term is summed.
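A minimal sketch of that fix, assuming the asker's loss_function signature: sum the BCE per batch (reduction='sum' is the current PyTorch spelling of the deprecated size_average=False) so both terms are on the same scale.

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    # Sum the reconstruction loss over pixels and batch, matching the summed KLD
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

Equivalently, you could divide the KLD term by the batch size and keep the BCE averaged; what matters is that both terms use the same reduction.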

benjaminplanche answered Sep 25 '22 01:09

Yes, try different weight factors for the KLD loss term. Down-weighting the KLD term resolved the same identical-output issue on the CelebA dataset (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html); see the sketch below.
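For illustration, a minimal sketch of such a weighting (a beta-VAE-style factor; the beta=0.001 default is only an assumed starting value to tune, not one taken from this answer):

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar, beta=0.001):
    # beta < 1 down-weights the KLD term so it does not overwhelm the reconstruction loss
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD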

Md. Alimoor Reza answered Sep 22 '22 01:09