python, pytorch, loss-function, autoencoder

How is the KL-divergence in PyTorch code related to the formula?


In the VAE tutorial, the KL-divergence between the approximate posterior $\mathcal{N}(\mu, \Sigma)$ and the standard normal prior $\mathcal{N}(0, I)$ is defined by:

$$D_{KL}\big[\mathcal{N}(\mu, \Sigma)\,\|\,\mathcal{N}(0, I)\big] = \frac{1}{2}\left(\operatorname{tr}(\Sigma) + \mu^\top\mu - k - \log\det\Sigma\right)$$

In many codebases, such as here, here, and here, it is implemented as:

    KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

or

    def latent_loss(z_mean, z_stddev):
        mean_sq = z_mean * z_mean
        stddev_sq = z_stddev * z_stddev
        return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

How are they related? Why is there no "tr" or ".transpose()" in the code?


Solution

  • The expressions in the code you posted assume X is an uncorrelated multivariate Gaussian random variable. This is apparent from the lack of cross terms in the determinant of the covariance matrix. Therefore the mean vector and covariance matrix take the forms

    $$\mu = \begin{bmatrix}\mu_1 \\ \vdots \\ \mu_k\end{bmatrix}, \qquad \Sigma = \operatorname{diag}\!\left(\sigma_1^2, \ldots, \sigma_k^2\right)$$

    Using this, we can quickly derive the following equivalent representations for the components of the original expression:

    $$\operatorname{tr}(\Sigma) = \sum_{i=1}^{k}\sigma_i^2, \qquad \mu^\top\mu = \sum_{i=1}^{k}\mu_i^2, \qquad \log\det\Sigma = \sum_{i=1}^{k}\log\sigma_i^2$$

    Substituting these back into the original expression gives

    $$D_{KL}\big[\mathcal{N}(\mu, \Sigma)\,\|\,\mathcal{N}(0, I)\big] = \frac{1}{2}\sum_{i=1}^{k}\left(\sigma_i^2 + \mu_i^2 - \log\sigma_i^2 - 1\right)$$

    This is exactly what both snippets compute element-wise, with no need for tr, .transpose() or a determinant: for a diagonal covariance they all collapse into per-dimension sums. In the first snippet logv stores log σ², so logv.exp() is σ² and mean.pow(2) is μ²; the second snippet works with the standard deviation directly and averages over dimensions instead of summing. A quick numerical check is sketched below.
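Here is a minimal sanity-check sketch, assuming PyTorch with torch.distributions is available; the tensor names mu and logv mirror the first snippet, and the dimensionality k = 5 is chosen arbitrarily for illustration. It evaluates the matrix form of the formula, the element-wise sum derived above, and the snippet from the question, and cross-checks them against the library's KL for univariate normals:

    import torch
    from torch.distributions import Normal, kl_divergence

    torch.manual_seed(0)
    k = 5                          # latent dimensionality (illustrative)
    mu = torch.randn(k)            # mean vector of q(z|x)
    logv = torch.randn(k)          # log-variance, as in the first snippet
    sigma_sq = logv.exp()          # per-dimension variances sigma_i^2
    Sigma = torch.diag(sigma_sq)   # diagonal covariance matrix

    # Matrix form: 1/2 * (tr(Sigma) + mu^T mu - k - log det Sigma)
    kl_matrix = 0.5 * (torch.trace(Sigma) + mu @ mu - k - torch.logdet(Sigma))

    # Element-wise form: 1/2 * sum_i (sigma_i^2 + mu_i^2 - log sigma_i^2 - 1)
    kl_elementwise = 0.5 * torch.sum(sigma_sq + mu.pow(2) - torch.log(sigma_sq) - 1)

    # The first snippet from the question
    kl_code = -0.5 * torch.sum(1 + logv - mu.pow(2) - logv.exp())

    # Library cross-check: KL between N(mu_i, sigma_i) and N(0, 1), summed over dimensions
    kl_dist = kl_divergence(Normal(mu, sigma_sq.sqrt()),
                            Normal(torch.zeros(k), torch.ones(k))).sum()

    print(kl_matrix, kl_elementwise, kl_code, kl_dist)  # all four values agree

All four numbers coincide up to floating-point error, which is why the shorter element-wise expressions are used in practice: with a diagonal covariance there is nothing left for tr or .transpose() to do.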