KL Divergence of two torch.distributions.Distribution objects

I'm trying to work out how to compute the KL divergence between two torch.distributions.Distribution objects. So far I couldn't find a function that does this. Here is what I've tried:

import torch as t
from torch import distributions as tdist
import torch.nn.functional as F

def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)  

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)

print(kl_divergence(a, b))  # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal
asked by ikamen

1 Answer

torch.nn.functional.kl_div computes the KL-divergence *loss*: it expects tensors of (log-)probabilities as input, not Distribution objects, which is why your call raises a TypeError. The KL divergence between two torch.distributions.Distribution objects can be computed with torch.distributions.kl.kl_divergence.
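For example, applied to the two Normal distributions from the question (for these parameters the analytic KL divergence is 0.5):

import torch.distributions as tdist
from torch.distributions.kl import kl_divergence

a = tdist.Normal(0., 1.)
b = tdist.Normal(1., 1.)

# kl_divergence dispatches to the closed-form implementation registered
# for the (Normal, Normal) pair, so no sampling is involved.
print(kl_divergence(a, b))  # tensor(0.5000)

For distribution pairs with no registered closed-form KL, kl_divergence raises a NotImplementedError.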

answered by jodag


