
I would like to implement a VAE with a Dirichlet distributed latent space in Python.

Since the standard reparametrization trick does not work for the Dirichlet distribution, I would use Implicit Reparameterization Gradients for my endeavor.

Luckily, this is already implemented in Python as tfp.distributions.Dirichlet. Thus, I can sample within the network without worrying about the probabilistic elements.
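To make sure I understand the sampling part, here is a minimal sketch of what I mean (the concentration values are made-up placeholders; in the real model the encoder would produce them):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Pretend these are the raw encoder outputs; softplus keeps the concentrations positive.
raw = tf.Variable([[0.2, -1.0, 0.5]])
alpha_hat = tf.nn.softplus(raw) + 1e-4

with tf.GradientTape() as tape:
    q_z = tfd.Dirichlet(concentration=alpha_hat)
    z = q_z.sample()                  # uses implicit reparameterization gradients
    dummy_loss = tf.reduce_sum(z ** 2)

# Gradients flow back through the sample to the encoder parameters.
print(tape.gradient(dummy_loss, raw))
```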

However, it is not clear to me what the loss function looks like in this case.

In the Gaussian case, the loss function consists of the reconstruction loss and the KL divergence term (VAE (Keras)).

Instinctively, I would say that it should be the same, since the objective hasn't changed.

Of course, the KL divergence term is different since the latent space is not Gaussian. According to Dirichlet Variational Autoencoder it should be: $$ KL(Q||P)=\sum_k \log\Gamma(\alpha_k)-\sum_k \log\Gamma(\hat{\alpha}_k)+\sum_k(\hat{\alpha}_k-\alpha_k)\cdot\psi(\hat{\alpha}_k) $$
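For concreteness, this is the kind of loss I would implement (a sketch only; `x`, `x_decoded`, `alpha_hat` and `prior_concentration` are placeholder names, and I use TFP's built-in KL between two Dirichlet distributions rather than typing out the formula):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def dirichlet_vae_loss(x, x_decoded, alpha_hat, prior_concentration):
    """Negative ELBO: Bernoulli reconstruction term + KL(q(z|x) || p(z))."""
    # Reconstruction term: binary cross-entropy summed over the input dimensions.
    reconstruction = -tf.reduce_sum(
        x * tf.math.log(x_decoded + 1e-7)
        + (1.0 - x) * tf.math.log(1.0 - x_decoded + 1e-7),
        axis=-1)

    q_z = tfd.Dirichlet(concentration=alpha_hat)             # posterior from the encoder
    p_z = tfd.Dirichlet(concentration=prior_concentration)   # fixed Dirichlet prior
    kl = tfd.kl_divergence(q_z, p_z)                         # closed-form Dirichlet KL

    return tf.reduce_mean(reconstruction + kl)
```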

My question:

  • Is my assumption about the loss function correct?
