What is the meaning of the implementation of the KL divergence in Keras?

I'm a little confused about how the KL divergence is applied, specifically in Keras, but I think the question is general to deep learning applications. In Keras, the KL loss function is defined like this:

from keras import backend as K

def kullback_leibler_divergence(y_true, y_pred):
    # Clip both inputs to [epsilon, 1] to avoid log(0) and division by zero.
    y_true = K.clip(y_true, K.epsilon(), 1)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    # Per-example KL divergence: sum over the last (class) axis.
    return K.sum(y_true * K.log(y_true / y_pred), axis=-1)

In my model, y_true and y_pred are matrices; each row of y_true is a one-hot encoding for one training example, and each row of y_pred is the model's output (a probability distribution) for that example.

I can run this KL divergence calculation on any given pair of rows from y_true and y_pred and get the expected result. The mean of these KL divergence results over the rows matches the loss reported by Keras in the training history. But that aggregation - running the KL divergence on each row and taking the mean - doesn't happen within the loss function itself. In contrast, I understand MAE or MSE to aggregate across the examples:

def mean_squared_error(y_true, y_pred):
    # Squared error, averaged over the last axis.
    return K.mean(K.square(y_pred - y_true), axis=-1)

For the KL divergence, it's not totally obvious to me that taking the mean across the examples is the right thing to do. I guess the idea is that the examples are random samples from the true distribution, so they should appear in proportion to their probability. But that seems to make a pretty strong assumption about how the training data was collected. I haven't really seen this aspect (aggregating across samples from a dataset) addressed in online treatments of the KL divergence; I just see a lot of restatements of the basic formula.
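For concreteness, here is a small NumPy sketch of the per-row computation and batch mean I'm describing (the batch values are made up):

import numpy as np

# Toy batch: one-hot targets and predicted distributions (rows sum to 1).
y_true = np.array([[0., 1., 0.],
                   [1., 0., 0.]])
y_pred = np.array([[0.2, 0.7, 0.1],
                   [0.6, 0.3, 0.1]])

eps = 1e-7  # stand-in for K.epsilon()
y_true_c = np.clip(y_true, eps, 1)
y_pred_c = np.clip(y_pred, eps, 1)

per_row_kl = np.sum(y_true_c * np.log(y_true_c / y_pred_c), axis=-1)
print(per_row_kl)         # one KL value per example
print(per_row_kl.mean())  # this mean matches the loss Keras reports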

So my questions are:

  1. Is my interpretation of what Keras is doing to come up with the KL divergence loss (i.e. averaging the per-row KL divergences) correct?

  2. Why is this the right thing to do?

  3. From an implementation perspective, why doesn't the definition of the loss function in Keras do the aggregation over the rows the way MAE or MSE does?

asked Jun 05 '17 by mechner


1 Answer

Kullback-Leibler divergence is a measure of how one probability distribution differs from another. The KL divergence implemented in Keras assumes two discrete probability distributions (hence the sum).

The exact form of your KL loss function depends on the underlying probability distributions. A common use case is that the neural network models the parameters of a probability distribution P (e.g. a Gaussian), and the KL divergence is then used in the loss function to measure the similarity between the modelled distribution and some other, known distribution (potentially Gaussian as well). E.g. a network outputs two vectors, mu and sigma^2: mu forms the mean of a Gaussian distribution P, while sigma^2 is the diagonal of the covariance matrix Sigma. A possible loss function is then the KL divergence between the Gaussian P described by mu and Sigma and a unit Gaussian N(0, I). The exact form of the KL divergence in that case can be derived analytically, yielding a custom Keras loss function that is not at all equal to the KL divergence implemented in Keras.
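For that diagonal-Gaussian-versus-unit-Gaussian case, the closed form works out to KL(N(mu, diag(sigma^2)) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2). A minimal sketch of such a custom loss, assuming the network emits mu and log_var = log(sigma^2) (the names and the log-variance parameterization are my choices, not anything prescribed by Keras):

from keras import backend as K

def gaussian_kl_to_unit_normal(mu, log_var):
    # Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    # summed over the latent dimensions: one value per example.
    return -0.5 * K.sum(1.0 + log_var - K.square(mu) - K.exp(log_var), axis=-1)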

In the original paper that introduced Variational Auto-Encoders, the loss function is summed over the samples in the minibatch and then multiplied by a factor N/M, where N is the size of the entire dataset and M is the size of the minibatch. See equations 8 and 10 in https://arxiv.org/abs/1312.6114.
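Written out, that minibatch estimator (equation 8 in the paper) is:

L(theta, phi; X) ≈ (N / M) * sum_{i=1..M} L(theta, phi; x_i)

so the per-sample losses are summed over the M minibatch samples and rescaled to estimate the loss over all N samples.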

answered Nov 15 '22 by datwelk