From my understanding of VAEs, there is a step in the middle of training where, after the encoder produces a mean and a standard deviation, random samples are drawn from that learned distribution to form the encoded vector that the decoder then decodes. I understand how the KL divergence is used to push the learned distribution towards the standard Gaussian, but I don't understand how the reconstruction loss can be backpropagated past this sampling step. Random sampling is not a differentiable operation, so how can the gradients propagate through it? Is my understanding of VAEs wrong?
-
Does this answer your question? [How does the reparameterization trick for VAEs work and why is it important?](https://stats.stackexchange.com/questions/199605/how-does-the-reparameterization-trick-for-vaes-work-and-why-is-it-important) – David Dao Jun 17 '20 at 22:51
1 Answer
This is the reparameterization trick. The sampling step
$$x = \text{sample}(\mathcal{N}(\mu, \sigma^2))$$
is not backpropagable with respect to $\mu$ or $\sigma$. However, we can rewrite it as
$$x = \mu + \sigma \cdot \text{sample}(\mathcal{N}(0, 1)),$$
which is distributed identically but is backpropagable, because the randomness now enters only through a noise term that does not depend on $\mu$ or $\sigma$.
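As a concrete illustration, here is a minimal PyTorch sketch of the trick (the tensor values and the toy loss are illustrative, not from the original answer):

```python
import torch

# Encoder outputs for a toy 2-dimensional latent (illustrative values).
mu = torch.tensor([0.5, -1.0], requires_grad=True)    # mean
sigma = torch.tensor([1.0, 2.0], requires_grad=True)  # standard deviation

# Reparameterization: the randomness comes from a parameter-free N(0, 1) draw,
# so the sample is a differentiable function of mu and sigma.
eps = torch.randn_like(sigma)   # sample from N(0, 1)
x = mu + sigma * eps            # distributed as N(mu, sigma^2)

loss = x.pow(2).sum()           # stand-in for a reconstruction loss
loss.backward()                 # gradients reach mu and sigma
print(mu.grad, sigma.grad)
```

Equivalently, `torch.distributions.Normal(mu, sigma).rsample()` draws a sample via this same reparameterization, whereas `.sample()` does not track gradients.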

shimao
-
Does this mean that we can't build autoencoders that use a different distribution, one that can't be reparameterized this way? – enumaris Apr 26 '18 at 15:52
-
@enumaris Most distributions can be reparameterized. For example, you can use a categorical latent space via the Gumbel-softmax trick. – shimao Apr 26 '18 at 15:55
-
But in theory the normal distribution is all you'll ever need, since a sufficiently powerful function approximator can always map the normal distribution to any arbitrary probability distribution. – shimao Apr 26 '18 at 15:56
-
Hey, thank you for this answer - this really helped me! But I have a question: why is $\sigma$ squared when used as an argument to $\mathcal{N}$ but not squared when used as a factor outside? – blue-phoenix Aug 09 '18 at 07:37
-
@blue-phoenix Scaling a random variable by a factor of $k$ scales the variance by a factor of $k^2$. – shimao Aug 09 '18 at 07:40
-
Suppose I have a `normal()` implementation that takes the standard deviation as its argument, not the variance (and is not backpropagable, which is why I'm asking). Do I then need to use the square root of $\sigma$ as the factor outside? – blue-phoenix Aug 09 '18 at 07:44
-
@blue-phoenix by convention, the two arguments to $\mathcal{N}$ are mean and variance respectively. $\sigma$ is the typical letter for standard deviation. – shimao Aug 09 '18 at 07:45
-
Concretely, I'm talking about the PyTorch implementation `normal(mean, std, out=None)`, whose documentation says that `std` is a tensor with the standard deviation of each output element's normal distribution. So I assumed this function takes $\mu$, $\sigma$ rather than $\mu$, $\sigma^2$, and I'm wondering whether the formula above then changes to $x = \mu + \sqrt{\sigma} \cdot \text{sample}(\mathcal{N}(0, 1))$? Or does it make no difference because the variance (resp. standard deviation) equals 1? – blue-phoenix Aug 09 '18 at 07:55
-
@blue-phoenix There's no need to add the square root. The fact that a particular implementation of the normal distribution is parameterized by the standard deviation rather than the variance doesn't change any of the math (see the short numerical check after these comments). – shimao Aug 09 '18 at 07:58
-
@shimao: Excuse me, why is it not backpropagable with respect to $\mu$ and/or $\sigma$? – Hossein Oct 01 '19 at 10:57
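To make the standard-deviation-versus-variance point in the comments concrete, here is a small numerical check (a sketch added for illustration, not part of the original thread). Scaling draws from $\mathcal{N}(0, 1)$ by $\sigma$ yields samples whose standard deviation is $\sigma$ and whose variance is $\sigma^2$, so when a sampler is parameterized by the standard deviation you pass $\sigma$ directly, with no square root:

```python
import torch

torch.manual_seed(0)
mu, sigma = 2.0, 3.0                         # sigma is the standard deviation
samples = mu + sigma * torch.randn(100_000)  # reparameterized draws from N(mu, sigma^2)

print(samples.mean().item())  # close to 2.0
print(samples.std().item())   # close to 3.0, i.e. variance close to sigma**2 = 9.0
```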