I'm a little confused about how the KL divergence is applied, specifically in Keras, but I think the question is general to deep learning applications. In Keras, the KL loss function is defined like this:
from keras import backend as K

def kullback_leibler_divergence(y_true, y_pred):
    y_true = K.clip(y_true, K.epsilon(), 1)  # clip to avoid log(0)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    return K.sum(y_true * K.log(y_true / y_pred), axis=-1)  # sum over the last axis only
In my model, y_true and y_pred are matrices; each row of y_true is a one-hot encoding for one training example, and each row of y_pred is the model's output (a probability distribution) for that example. I can run this KL divergence calculation on any given pair of rows from y_true and y_pred and get the expected result, and the mean of these per-row KL divergences matches the loss reported by Keras in the training history. But that aggregation - running KL divergence on each row and taking the mean - doesn't happen within the loss function itself. In contrast, I understand MAE or MSE to aggregate across the examples:
def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)
For the KL divergence, it's not totally obvious to me that taking the mean across the examples is the right thing to do. I guess the idea is that the examples are random samples from the true distribution, so they should appear in proportion to their probability. But that seems to make a pretty strong assumption about how the training data was collected. I haven't really seen this aspect (aggregating across samples from a dataset) addressed in the online treatments of the KL divergence; I just see a lot of redefinition of the basic formula.
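For concreteness, here is a small NumPy sketch (the batch values are made up) of the aggregation I'm describing - a per-row KL divergence followed by a mean over the rows:

import numpy as np

# Made-up batch: two training examples over three classes.
eps = 1e-7
y_true = np.array([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 0.0]])   # one-hot rows
y_pred = np.array([[0.1, 0.8, 0.1],
                   [0.6, 0.3, 0.1]])   # model outputs, each row a distribution

yt = np.clip(y_true, eps, 1)
yp = np.clip(y_pred, eps, 1)
row_kl = np.sum(yt * np.log(yt / yp), axis=-1)  # one KL value per row
print(row_kl)         # per-row KL divergences
print(row_kl.mean())  # what I see reported as the loss in the training history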
So my questions are:
1. Is this interpretation of what Keras is doing to come up with the KL divergence loss (i.e. averaging the per-row KL divergences) correct?
2. Why is this the right thing to do?
3. From an implementation perspective, why doesn't the definition of the loss function in Keras do the aggregation over the rows the way MAE or MSE does?
The Kullback-Leibler divergence, or KL divergence, quantifies how much one probability distribution differs from another.
Intuitively, it measures how far a given arbitrary distribution is from the true distribution. If the two distributions match perfectly, D_{KL}(p||q) = 0; otherwise it can take values between 0 and ∞. The lower the KL divergence, the better we have matched the true distribution with our approximation.
In other words, the KL divergence is zero when the two distributions are equal, and positive when they differ.
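As a quick sketch (the distributions p and q below are made up), the divergence vanishes for identical distributions and is positive otherwise:

import numpy as np

# Made-up discrete distributions; all entries must be strictly positive
# for the formula to be well-defined without clipping.
def kl(p, q):
    return np.sum(p * np.log(p / q))

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])
print(kl(p, p))  # 0.0: identical distributions
print(kl(p, q))  # ~0.025: positive for different distributions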
This is why KL divergence is used as a measure of the difference between two probability distributions. The VAE objective function, for example, includes a KL divergence term that needs to be minimized.
Kullback-Leibler divergence is a measure of the dissimilarity between two probability distributions. The KL divergence implemented in Keras assumes two discrete probability distributions (hence the sum over the last axis).
The exact form of your KL loss function depends on the underlying probability distributions. A common use case is that the neural network models the parameters of a probability distribution P (e.g. a Gaussian), and the KL divergence is then used in the loss function to measure the closeness of the modelled distribution to some other, known distribution (potentially Gaussian as well). For example, a network outputs two vectors, mu and sigma^2: mu is the mean of a Gaussian distribution P, while sigma^2 is the diagonal of the covariance matrix Sigma. A possible loss function is then the KL divergence between the Gaussian P described by mu and Sigma and a unit Gaussian N(0, I). The exact form of the KL divergence in that case can be derived analytically, yielding a custom Keras loss function that is not at all equal to the KL divergence implemented in Keras.
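As a sketch of what such a custom loss might look like (the function name and the choice of log sigma^2 as the network output are assumptions, not part of the Keras API), the analytic KL between N(mu, diag(sigma^2)) and N(0, I) is 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2):

from keras import backend as K

# Hypothetical helper: closed-form KL divergence between the diagonal
# Gaussian N(mu, diag(exp(log_var))) and the unit Gaussian N(0, I).
# Networks commonly output log_var = log(sigma^2) for numerical stability.
def gaussian_kl_to_unit(mu, log_var):
    return 0.5 * K.sum(K.exp(log_var) + K.square(mu) - 1.0 - log_var, axis=-1)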
In the original paper that introduced Variational Auto-Encoders, the loss function is summed over the samples in the minibatch and then multiplied by a factor N/M, where N is the size of the entire dataset and M is the size of the minibatch. See equations 8 and 10 in https://arxiv.org/abs/1312.6114.
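A minimal sketch of that minibatch estimator (the names are mine, not from the paper):

import numpy as np

# Eq. 8 of Kingma & Welling: sum the per-example losses over the
# minibatch and rescale by N/M to estimate the full-dataset objective.
def minibatch_objective(per_example_losses, dataset_size):
    M = len(per_example_losses)   # minibatch size
    N = dataset_size              # size of the entire dataset
    return (N / M) * np.sum(per_example_losses)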