KL divergence between two torch.distributions.Distribution objects


I'm trying to work out how to compute the KL divergence between two torch.distributions.Distribution objects. I haven't found a function that does this so far. Here is what I've tried:

import torch as t
from torch import distributions as tdist
import torch.nn.functional as F

def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)  

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)

print(kl_divergence(a, b))  # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal

Solution

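  • torch.nn.functional.kl_div computes the KL-divergence loss between tensors (its input is expected to contain log-probabilities), which is why it raises a TypeError when given Distribution objects. To compute the KL divergence between two distributions, use torch.distributions.kl.kl_divergence(p, q) instead.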
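
As a minimal sketch of that suggestion, reusing the two Normal distributions from the question (the variable names kl_ab and kl_ba are only illustrative):

```python
from torch import distributions as tdist
from torch.distributions.kl import kl_divergence

# Same distributions as in the question; floats keep the parameters floating-point.
a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)

# kl_divergence(p, q) computes KL(p || q) and returns a tensor.
kl_ab = kl_divergence(a, b)
kl_ba = kl_divergence(b, a)

# For two Normals with equal scale, KL = (mu_a - mu_b)^2 / (2 * sigma^2) = 0.5.
print(kl_ab.item())  # 0.5
print(kl_ba.item())  # 0.5 here only because the scales match; KL is not symmetric in general
```

Note that kl_divergence relies on a closed-form rule being registered for the given pair of distribution types; for a pair without one it raises NotImplementedError, in which case a sampling-based estimate would be needed.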