
KL Divergence for two probability distributions in PyTorch

I have two probability distributions. How should I find the KL-divergence between them in PyTorch? The regular cross entropy only accepts integer labels.

asked Apr 17 '18 by Mojtaba Komeili


People also ask

What is the KL divergence between two equal distributions?

The Kullback-Leibler Divergence score, or KL divergence score, quantifies how much one probability distribution differs from another probability distribution. The KL divergence between two distributions Q and P is often stated using the following notation: KL(P || Q)
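To answer the question directly: when the two distributions are identical, every term p_i * log(p_i / q_i) is zero, so the divergence is exactly 0. A minimal plain-Python sketch (the kl helper below is illustrative, not from the post):

```python
import math

def kl(p, q):
    # KL(P || Q) = sum_i p_i * log(p_i / q_i); terms with p_i = 0 contribute 0
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.36, 0.48, 0.16]
print(kl(p, p))  # 0.0: identical distributions have zero divergence
```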

What is Log_prob in Pytorch?

log_prob(value) returns the log of the probability density/mass function evaluated at value, where value is a Tensor. Separately, the mean property returns the mean of the distribution.
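For illustration (a hedged sketch using torch.distributions with a Categorical distribution; the specific numbers are not from the original excerpt), log_prob returns the log of the probability assigned to a given outcome:

```python
import math
import torch
from torch.distributions import Categorical

# A categorical distribution over three outcomes
dist = Categorical(probs=torch.tensor([0.36, 0.48, 0.16]))

# log_prob(value) returns log P(X = value)
lp = dist.log_prob(torch.tensor(1))
print(lp.item())  # close to math.log(0.48), about -0.734
```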

Is the Kullback-Leibler divergence a distance metric?

Although the KL divergence measures the "distance" between two distributions, it is not a distance metric, because it is not symmetric: the KL divergence from p(x) to q(x) is generally not the same as the KL divergence from q(x) to p(x).
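The asymmetry is easy to check numerically; here is a plain-Python sketch (the kl helper and the uniform Q are illustrative choices, not from the post):

```python
import math

def kl(p, q):
    # KL(P || Q) = sum_i p_i * log(p_i / q_i)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.36, 0.48, 0.16]
q = [1/3, 1/3, 1/3]

print(kl(p, q))  # about 0.0853
print(kl(q, p))  # about 0.0975: a different value, so KL is not symmetric
```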

Is KL divergence same as cross-entropy?

KL divergence is the relative entropy: the difference between the cross-entropy of the actual and predicted probability distributions and the entropy of the actual distribution. It is equal to 0 when the predicted probability distribution is the same as the actual one.
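That relationship, KL(P || Q) = H(P, Q) - H(P), can be verified directly in plain Python (the two distributions below are illustrative):

```python
import math

p = [0.36, 0.48, 0.16]   # actual distribution
q = [0.30, 0.50, 0.20]   # predicted distribution

entropy = -sum(pi * math.log(pi) for pi in p)                     # H(P)
cross_entropy = -sum(pi * math.log(qi) for pi, qi in zip(p, q))   # H(P, Q)
kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))          # KL(P || Q)

print(abs(kl - (cross_entropy - entropy)))  # ~0: the identity holds
```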


2 Answers

Yes, PyTorch has a function named kl_div under torch.nn.functional to directly compute KL-divergence between tensors. Note that it expects the first argument to contain log-probabilities and the second to contain probabilities. Suppose you have tensors a and b of the same shape, where a holds log-probabilities and b holds probabilities. You can use the following code:

import torch.nn.functional as F
out = F.kl_div(a, b)  # computes KL(b || exp(a)) with the default 'mean' reduction

For more details, see the above method documentation.
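In practice the first argument must be log-probabilities; here is a hedged sketch of typical usage (the shapes and reduction='batchmean' are illustrative assumptions, not from the answer):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# F.kl_div expects log-probabilities first and (by default) probabilities second
log_q = F.log_softmax(torch.randn(4, 3), dim=1)   # predicted log-probabilities
p = F.softmax(torch.randn(4, 3), dim=1)           # target probabilities

# 'batchmean' sums over elements and divides by batch size, matching
# the mathematical definition of KL divergence per sample
loss = F.kl_div(log_q, p, reduction='batchmean')
print(loss)  # a non-negative scalar tensor
```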

answered Sep 27 '22 by jdhao

Note that F.kl_div does not match the formula in the Wikipedia article directly: its first argument takes log-probabilities, and the argument order is reversed relative to the usual KL(P || Q) notation.

I use the following:

import torch
import torch.nn.functional as F

# this is the same example as in the wiki article
P = torch.Tensor([0.36, 0.48, 0.16])
Q = torch.Tensor([0.333, 0.333, 0.333])

# direct computation of KL(P || Q)
(P * (P / Q).log()).sum()
# tensor(0.0863), 10.2 µs ± 508 ns

# kl_div expects log-probabilities first, target probabilities second
F.kl_div(Q.log(), P, None, None, 'sum')
# tensor(0.0863), 14.1 µs ± 408 ns

Compared to kl_div, the direct computation is even faster here.
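For completeness (not part of the original answer), torch.distributions provides a higher-level kl_divergence. Note that Categorical normalizes its probs, so the 0.333 entries become exact thirds and the result is slightly below the 0.0863 computed above:

```python
import torch
from torch.distributions import Categorical, kl_divergence

P = Categorical(probs=torch.tensor([0.36, 0.48, 0.16]))
Q = Categorical(probs=torch.tensor([0.333, 0.333, 0.333]))  # normalized to exact thirds

kl_pq = kl_divergence(P, Q)
print(kl_pq)  # roughly tensor(0.0853)
```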

answered Sep 27 '22 by hantian_pang