I'm looking at the following implementation of a VAE: https://github.com/jmtomczak/vae_vpflows/blob/master/models/VAE.py
The KL divergence is implemented as:
# KL
log_p_z = log_Normal_standard(z_q, dim=1)
log_q_z = log_Normal_diag(z_q, z_q_mean, z_q_logvar, dim=1)
KL = -(log_p_z - log_q_z)
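If I'm reading it correctly, this looks like a single-sample Monte Carlo estimate of the KL divergence, evaluated at the drawn sample z_q instead of using the closed-form Gaussian expression:

$$\operatorname{KL}\big(q(z \mid x)\,\|\,p(z)\big) = \mathbb{E}_{z \sim q(z \mid x)}\big[\log q(z \mid x) - \log p(z)\big] \approx \log q(z_q \mid x) - \log p(z_q)$$

but I haven't found this written down anywhere.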
In this snippet, z_q is a batch of samples drawn from the approximate posterior q(z|x), and z_q_mean and z_q_logvar are the encoder's predicted means and log variances that parameterize the distribution each sample is drawn from. log_Normal_standard and log_Normal_diag are implemented as follows:
def log_Normal_diag(x, mean, log_var, average=False, dim=None):
    # log-density of a diagonal Gaussian, with the additive constant
    # -0.5*log(2*pi) per dimension omitted
    log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )

def log_Normal_standard(x, average=False, dim=None):
    # log-density of a standard Gaussian, omitting the same -0.5*log(2*pi)
    # constant, so it cancels in the difference log_p_z - log_q_z
    log_normal = -0.5 * torch.pow( x , 2 )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )
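To sanity-check that reading, here is a minimal sketch I put together (my own code, not from the repo; the toy values for mean and log_var and the names kl_closed_form and kl_mc are mine) that compares the estimator, averaged over many samples, against the closed-form KL between a diagonal Gaussian and a standard normal:

import torch

def log_Normal_diag(x, mean, log_var, average=False, dim=None):
    log_normal = -0.5 * (log_var + torch.pow(x - mean, 2) / torch.exp(log_var))
    return torch.mean(log_normal, dim) if average else torch.sum(log_normal, dim)

def log_Normal_standard(x, average=False, dim=None):
    log_normal = -0.5 * torch.pow(x, 2)
    return torch.mean(log_normal, dim) if average else torch.sum(log_normal, dim)

torch.manual_seed(0)
mean = torch.randn(4)     # made-up encoder means for a single input
log_var = torch.randn(4)  # made-up encoder log variances

# Closed-form KL( N(mean, diag(exp(log_var))) || N(0, I) ), summed over dimensions
kl_closed_form = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

# Monte Carlo estimate: average -(log_p_z - log_q_z) over many samples z ~ q(z|x)
n_samples = 100000
z = mean + torch.exp(0.5 * log_var) * torch.randn(n_samples, 4)
log_p_z = log_Normal_standard(z, dim=1)
log_q_z = log_Normal_diag(z, mean, log_var, dim=1)
kl_mc = torch.mean(-(log_p_z - log_q_z))

print(kl_closed_form.item(), kl_mc.item())  # if my reading is right, these should be close

If that's correct, the repo seems to use a single-sample version of this estimate per data point, presumably relying on averaging over the minibatch and over training iterations.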
Even so, I'm unfamiliar with this way of computing the KL divergence from log-densities of normal distributions rather than in closed form, and I can't find any supplementary material that matches this formulation.
Can anyone point me to equations or a reference that match it?