
From my understanding of VAEs, there's a step during training where, after the encoder produces a mean and a standard deviation, random samples are drawn from the learned distribution to create the encoded vector that the decoder then decodes. I understand how the KL divergence is used to force the learned distribution to be approximately a standard Gaussian, but I don't understand how the reconstruction loss can be backpropagated past this sampling step. Random sampling is not a differentiable operation, so how can the gradients propagate past it? Is my understanding of VAEs wrong?

enumaris
  • Does this answer your question? [How does the reparameterization trick for VAEs work and why is it important?](https://stats.stackexchange.com/questions/199605/how-does-the-reparameterization-trick-for-vaes-work-and-why-is-it-important) – David Dao Jun 17 '20 at 22:51

1 Answer


The reparameterization trick.

$$x = \text{sample}(\mathcal{N}(\mu, \sigma^2))$$

is not backpropable with respect to $\mu$ or $\sigma$, because the sampling operation itself has no gradient. However, we can rewrite this as:

$$x = \mu + \sigma\ \text{sample}( \mathcal{N}(0, 1))$$

which is equivalent in distribution but backpropable: $\mu$ and $\sigma$ now enter only through differentiable operations, and all the randomness is isolated in a parameter-free standard-normal sample.
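A minimal sketch of this in PyTorch (hypothetical tensor names; in a real VAE, `mu` and `log_var` would be produced by the encoder, and predicting the log-variance is a common stability choice, not something mandated by the trick itself):

```python
import torch

# Hypothetical encoder outputs: batch of 4 examples, 3-dim latent space.
mu = torch.zeros(4, 3, requires_grad=True)
log_var = torch.zeros(4, 3, requires_grad=True)

# Reparameterization: draw parameter-free noise, then shift and scale it.
eps = torch.randn_like(mu)        # sample from N(0, 1); carries no gradient
sigma = torch.exp(0.5 * log_var)  # standard deviation from log-variance
z = mu + sigma * eps              # differentiable in mu and log_var

# Gradients now flow from any loss on z back to the encoder parameters.
loss = z.pow(2).sum()
loss.backward()
print(mu.grad is not None, log_var.grad is not None)  # True True
```

Had we instead called `torch.normal(mu, sigma)` directly, `eps` would never appear in the graph and no gradient would reach `mu` or `sigma`.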

shimao
  • Does this mean that we can't build autoencoders that use a different distribution that can't be reparameterized this way? – enumaris Apr 26 '18 at 15:52
  • @enumaris Most distributions can be reparameterized. For example, you can use a categorical latent space via the Gumbel-softmax trick. – shimao Apr 26 '18 at 15:55
  • But in theory the normal distribution is all you'll ever need, since a sufficiently powerful function approximator can always map the normal distribution to any arbitrary probability distribution. – shimao Apr 26 '18 at 15:56
  • Hey, thank you for this answer - this really helped me! But I have a question: why is $\sigma$ squared when used as an argument to $\mathcal{N}$ but not squared when used outside? – blue-phoenix Aug 09 '18 at 07:37
  • @blue-phoenix Scaling a random variable by a factor of $k$ scales its variance by a factor of $k^2$. – shimao Aug 09 '18 at 07:40
  • So suppose I have a `normal()` implementation that takes the standard deviation as an argument rather than the variance (it is not backpropable, hence my question). Do I then need to use the square root of $\sigma$ as the factor outside? – blue-phoenix Aug 09 '18 at 07:44
  • @blue-phoenix by convention, the two arguments to $\mathcal{N}$ are mean and variance respectively. $\sigma$ is the typical letter for standard deviation. – shimao Aug 09 '18 at 07:45
  • Concretely I'm talking about the PyTorch implementation `normal(mean, std, out=None)` - `The :attr: 'std' is a tensor with the standard deviation of each output element's normal distribution.` So I assumed this function takes $\mu$, $\sigma$ instead of $\mu$, $\sigma^2$. So I'm wondering if the formula above then changes to $x = \mu + \sqrt{\sigma}\ \text{sample}(\mathcal{N}(0, 1))$? Or does it make no difference because the variance and the standard deviation both equal 1? – blue-phoenix Aug 09 '18 at 07:55
  • @blue-phoenix There's no need to add the square root. The fact that a particular implementation of the normal distribution is parameterized by the standard deviation rather than the variance doesn't change any of the math. – shimao Aug 09 '18 at 07:58
  • @shimao: excuse me, why is it not backpropable with respect to $\mu$ and/or $\sigma$? – Hossein Oct 01 '19 at 10:57
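To illustrate the point settled in the comments above (a hypothetical sketch, not any poster's code): whether a sampler is parameterized by the standard deviation or by the variance, the reparameterized form scales the unit noise by $\sigma$ itself, never by $\sqrt{\sigma}$:

```python
import torch

torch.manual_seed(0)
mu, sigma = 2.0, 3.0  # mean and standard deviation, so the variance is 9

# Reparameterized sampling: sigma itself is the scale factor.
eps = torch.randn(100_000)  # parameter-free N(0, 1) noise
samples = mu + sigma * eps

# The empirical moments match N(mu, sigma^2) = N(2, 9).
print(samples.mean().item())  # close to 2
print(samples.var().item())   # close to 9
```

Using `sqrt(sigma)` instead would produce samples with variance `sigma`, i.e. 3, which is not the intended distribution.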