When I train without the KL divergence term, the VAE reconstructs MNIST images almost perfectly but fails to generate proper new ones from random noise.
When I train with the KL divergence term, the VAE produces the same weird output both when reconstructing and when generating images.
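Here's the PyTorch code for the loss function:

    import torch
    import torch.nn.functional as F

    def loss_function(recon_x, x, mu, logvar):
        # Reconstruction term; size_average=True (deprecated, same as reduction='mean')
        # averages the BCE over every pixel of every image in the batch
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=True)
        # KL divergence between N(mu, exp(logvar)) and N(0, I), summed over batch and latent dims
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD
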
recon_x is the reconstructed image, x is the original image, mu is the mean vector, and logvar is the vector of log-variances.
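For reference, the KLD line computes the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior, with sigma_j^2 = exp(logvar_j) over J latent dimensions; the code then also sums this over every example in the batch:

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```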
What is going wrong here? Thanks in advance :)
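A possible reason is the numerical imbalance between the two losses: your BCE loss is averaged (cf. size_average=True), i.e. divided by the batch size times the 784 pixels, while the KLD one is summed over the batch and the latent dimensions, so the KLD term can easily dominate the total loss.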
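One way to put both terms on the same scale is to sum the BCE over pixels as well (reduction="sum") and then divide the total by the batch size, so both terms are per-example sums. This is a minimal sketch that assumes, like the question, 28×28 MNIST inputs flattened to 784 pixels and decoder outputs already passed through a sigmoid:

```python
import torch
import torch.nn.functional as F


def loss_function(recon_x, x, mu, logvar):
    """VAE loss with both terms summed per example, then averaged over the batch."""
    batch_size = x.size(0)
    # Sum the per-pixel BCE (instead of averaging it) so it is on the same scale as the KLD
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and batch
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Divide by the batch size so the loss is per example and independent of batch size
    return (bce + kld) / batch_size
```

With this reduction the BCE term is typically in the hundreds for MNIST, so the KLD no longer overwhelms it.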
Yes, try different weight factors for the KLD loss term. Down-weighting the KLD term resolved the same problem of identical reconstruction outputs on the CelebA dataset (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
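A weight factor on the KLD term can be added with one extra parameter. In this sketch the helper name `weighted_vae_loss` and the parameter `beta` are hypothetical; it keeps the per-example reduction from above and also returns both components so they can be logged separately during training:

```python
import torch
import torch.nn.functional as F


def weighted_vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Hypothetical helper: VAE loss with a tunable weight `beta` on the KLD term.

    Returns (total, bce, kld), each averaged over the batch.
    """
    batch_size = x.size(0)
    # Reconstruction term: summed over pixels, averaged over the batch
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum") / batch_size
    # KL term: summed over latent dimensions, averaged over the batch
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    # beta < 1 down-weights the KL term; beta = 1 gives the standard VAE objective
    return bce + beta * kld, bce, kld
```

No particular value of `beta` is implied here; it has to be tuned, for example by comparing reconstructions and samples generated from random noise for a few candidate values.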