Summary: My aim is to create a (probabilistic) neural network for classification that learns the distribution of its class probabilities. The Dirichlet distribution seems to be the natural choice. I am familiar with the reparametrization trick and would like to apply it here. I thought I had found a way to generate gamma-distributed random variables (which are needed to construct Dirichlet samples) within the network (detailed explanation below).
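For context, a Dirichlet sample can be built from independent Gamma variables: if $G_k \sim \mathrm{Gamma}(\alpha_k, 1)$ independently, then $(G_1, \dots, G_K)/\sum_k G_k \sim \mathrm{Dir}(\alpha_1, \dots, \alpha_K)$. This is why a reparametrizable Gamma sampler is enough to get a reparametrizable Dirichlet sampler. A minimal R sketch of the construction, still using the exact (non-reparametrized) `rgamma`; `rdirichlet_via_gamma` is a hypothetical helper name:

```r
# Hypothetical helper: Dirichlet(alpha) samples by normalising independent Gamma draws.
rdirichlet_via_gamma <- function(n, alpha) {
  k <- length(alpha)
  # column j holds n draws from Gamma(alpha[j], rate = 1)
  g <- matrix(rgamma(n * k, shape = rep(alpha, each = n), rate = 1),
              nrow = n, ncol = k)
  # dividing each row by its sum maps it onto the probability simplex
  g / rowSums(g)
}

set.seed(1)
alpha <- c(0.5, 2, 4)
x <- rdirichlet_via_gamma(10000, alpha)
colMeans(x)  # should be close to alpha / sum(alpha)
```

Replacing `rgamma` by a differentiable function of $(u, \alpha)$ is exactly what the approximation discussed below is meant to provide.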
My questions are:
- Does the sampling process for the gamma distribution described in the paper Dirichlet Variational Autoencoder actually work for all $\alpha > 0$, or have I misread it and it only works for $\alpha \le 1$?
- If it only works for $\alpha \le 1$, is there an alternative to the Dirichlet distribution in my case (e.g. a mixture of Gaussians as a continuous approximation of the discrete multinomial distribution)?
I have already read two posts that touch on the reparametrization trick for non-Gaussian distributions. The first one (Reparameterization trick for gamma distribution) made me think that my issue could not easily be resolved; the other one (Reparametrization trick with non-Gaussian distributions?) made me a little more optimistic. I read the paper mentioned in the second post (Dirichlet Variational Autoencoder). It says:
- Approximation with inverse Gamma CDF. A previous work Knowles (2015) suggested that, if $X \sim \mathrm{Gamma}(\alpha,\beta)$, and if $F(x; \alpha,\beta)$ is a CDF of the random variable $X$, the inverse CDF can be approximated as $F^{-1}(u; \alpha,\beta) \approx \beta^{-1}(u\alpha \Gamma(\alpha))^{1/\alpha}$ for $u$ a unit-uniform random variable.
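A quick way to see where this approximation comes from, and why it should only be trusted for small $\alpha$, is the series expansion of the lower incomplete gamma function $\gamma(\alpha, t)$. This is my own sketch, assuming $\beta$ is a rate parameter as in `rgamma(..., rate = beta)`:

$$F(x;\alpha,\beta) = \frac{\gamma(\alpha,\beta x)}{\Gamma(\alpha)}, \qquad \gamma(\alpha,t) = \sum_{k=0}^{\infty}\frac{(-1)^k\,t^{\alpha+k}}{k!\,(\alpha+k)} = \frac{t^{\alpha}}{\alpha} - \frac{t^{\alpha+1}}{\alpha+1} + \cdots$$

Keeping only the leading term and solving $u = F(x;\alpha,\beta)$ for $x$:

$$u \approx \frac{(\beta x)^{\alpha}}{\alpha\,\Gamma(\alpha)} \quad\Longrightarrow\quad x \approx \beta^{-1}\bigl(u\,\alpha\,\Gamma(\alpha)\bigr)^{1/\alpha}.$$

The dropped terms are negligible only when $\beta x$ is small, i.e. when most of the probability mass lies near zero. That holds for small $\alpha$ but not for $\alpha > 1$, where the bulk of the distribution sits around $\alpha/\beta$.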
When I compared the approximation with samples from R's `rgamma` function (varying $\alpha$, with $\beta = 1$), I found that it only works reasonably well when $\alpha \le 1$.
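The same comparison can be quantified without histograms by pushing the approximate quantiles through the exact CDF: if the approximation were exact, `pgamma(approx_qgamma(u))` would return `u`. The sketch below reports the largest gap over a grid of $u$ for each $\alpha$ (a Kolmogorov-Smirnov-style distance). `approx_qgamma` is a hypothetical helper implementing the quoted formula, with $\beta$ treated as a rate:

```r
# Hypothetical helper: the inverse-CDF approximation (u * alpha * Gamma(alpha))^(1/alpha) / beta
approx_qgamma <- function(u, alpha, beta = 1) {
  (u * alpha * gamma(alpha))^(1 / alpha) / beta
}

# For an exact inverse CDF, pgamma(approx_qgamma(u)) == u; measure the worst deviation
u_grid <- seq(0.001, 0.999, length.out = 999)
alphas <- c(0.1, 0.25, 0.5, 1, 2, 4, 10)
max_cdf_error <- sapply(alphas, function(a) {
  max(abs(pgamma(approx_qgamma(u_grid, a), shape = a, rate = 1) - u_grid))
})
round(setNames(max_cdf_error, alphas), 3)
```

As a sanity check on the trend: for $\alpha = 1$ the approximation reduces to $x = u$, while the exact quantile of $\mathrm{Exp}(1)$ is $-\log(1-u)$, so the error grows quickly as $u \to 1$.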
The original source of the approximation (Knowles, 2015) confirms that it is intended for small shape parameters only:
- For $a < 1$ and $(1-0.94z)\log(a) < -0.42$ we use $F_{a,b}(z) \approx (za\Gamma(a))^{1/a}/b$.
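The quoted condition can be checked numerically to see how often the small-shape formula is actually used. A sketch, where `small_shape_approx_ok` is a hypothetical helper that only encodes the quoted inequality (the branch Knowles uses otherwise is not reproduced here):

```r
# Hypothetical helper: TRUE where the quoted small-shape condition holds
small_shape_approx_ok <- function(z, a) {
  a < 1 & (1 - 0.94 * z) * log(a) < -0.42
}

# Fraction of a fine grid of uniform values z in (0, 1) for which the formula applies
z_grid <- seq(0.0005, 0.9995, by = 0.001)
shapes <- c(0.1, 0.25, 0.5, 0.9, 1, 2)
frac_used <- sapply(shapes, function(a) mean(small_shape_approx_ok(z_grid, a)))
setNames(frac_used, shapes)
```

On this grid the formula applies to roughly 87% of $z$ values for $a = 0.1$, roughly 42% for $a = 0.5$, and to none once $a \ge e^{-0.42} \approx 0.66$, because then $(1-0.94z)\log(a) \ge \log(a) \ge -0.42$.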
Here is the R code I used for the comparison (histograms of both sample sets, faceted by $\alpha$):
```r
library(tidyverse)

alpha <- c(0.1, 0.25, 0.5, 1, 2, 4, 10)  # shape parameters to compare
beta <- 1                                 # rate parameter
n <- 100000
u <- runif(n = n)                         # unit-uniform draws for the approximation

# Exact Gamma samples
values_actual <-
  map_df(alpha,
         function(a) tibble(data = rgamma(n = n, shape = a, rate = beta),
                            alpha = a)) %>%
  mutate(type = "actual")

# Approximate samples: (u * alpha * Gamma(alpha))^(1/alpha) / beta
values_approximated <-
  map_df(alpha,
         function(a) tibble(data = (u * a * gamma(a))^(1 / a) / beta,
                            alpha = a)) %>%
  mutate(type = "approximation")

bind_rows(values_actual, values_approximated) %>%
  mutate(type = as.factor(type)) %>%
  ggplot(aes(x = data)) +
  geom_histogram() +
  facet_grid(rows = vars(type),
             cols = vars(alpha)) +
  theme_classic() +
  labs(x = "")
```
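Since the point of the approximation is the reparametrization trick, it may help to spell out what it provides: with $u$ held fixed, the approximate sample $x(\alpha) = \beta^{-1}(u\alpha\Gamma(\alpha))^{1/\alpha}$ is a differentiable function of $\alpha$, so its pathwise derivative has a closed form. The sketch below derives it from $\log x = (\log u + \log\alpha + \log\Gamma(\alpha))/\alpha - \log\beta$ and checks it against a central finite difference. The function names are hypothetical, and the gradient inherits the bias of the approximation, so for $\alpha > 1$ it is the exact gradient of an inaccurate sample.

```r
# Hypothetical helper: approximate Gamma sample as a function of (u, alpha, beta)
approx_qgamma <- function(u, alpha, beta = 1) {
  (u * alpha * gamma(alpha))^(1 / alpha) / beta
}

# Pathwise derivative dx/dalpha with u held fixed:
# d log x / d alpha = (1/alpha + digamma(alpha)) / alpha - s / alpha^2,
# where s = log u + log alpha + lgamma(alpha)
d_sample_d_alpha <- function(u, alpha, beta = 1) {
  x <- approx_qgamma(u, alpha, beta)
  s <- log(u) + log(alpha) + lgamma(alpha)
  x * ((1 / alpha + digamma(alpha)) / alpha - s / alpha^2)
}

# Compare with a central finite difference at one (u, alpha) pair
u0 <- 0.3
a0 <- 0.5
h <- 1e-6
analytic <- d_sample_d_alpha(u0, a0)
fd <- (approx_qgamma(u0, a0 + h) - approx_qgamma(u0, a0 - h)) / (2 * h)
c(analytic = analytic, finite_difference = fd)
```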