
To my (very modest) understanding of variational inference, one tries to approximate an unknown distribution $p$ by finding a distribution $q$ that minimises the following:

$$KL(p\,||\,q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)}$$
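For concreteness, the sum above can be evaluated directly for small discrete distributions. The sketch below (the distributions `p` and `q` are made-up values) shows that every term needs $p(x)$, and that the divergence is not symmetric in $p$ and $q$.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) = sum_x p(x) * log(p(x) / q(x)) for discrete distributions.

    Terms with p(x) == 0 contribute 0 by convention."""
    total = 0.0
    for px, qx in zip(p, q):
        if px > 0:
            total += px * math.log(px / qx)
    return total

p = [0.5, 0.3, 0.2]  # hypothetical "true" distribution: needed for every term
q = [0.4, 0.4, 0.2]  # hypothetical candidate approximation

print(kl_divergence(p, q))  # KL(p || q)
print(kl_divergence(q, p))  # KL(q || p): a different number
print(kl_divergence(p, p))  # 0.0
```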

Whenever I invest time in understanding variational inference I keep hitting this formula and can't help but feel like I'm missing the point. It seems like I need to know $p$ in order to calculate $KL(p||q)$. But the whole point was that I did not know this distribution $p$.

It's this exact point that's been bugging me every time I try to read up on anything variational. What am I missing?

EDIT:

I'll add a few extra comments here in response to the answer by @wij, and I'll attempt to be more precise.

In the cases that I am interested in, it indeed seems perfectly reasonable to consider that the following holds:

$$p(\theta | D) = \frac{p(D|\theta)p(\theta)}{p(D)} \propto p(D|\theta)p(\theta)$$

In this case I could know what $p$ should look like up to proportionality, because I will have made a model choice for $p(D|\theta)$ and $p(\theta)$. Would I then be correct in saying that I need to pick a family of distributions for $q$ (let's say Gaussian) so that I can estimate $KL(p(\theta|D) || q)$? It feels like in this case I am trying to fit a Gaussian that is close to the non-normalized $p(D|\theta)p(\theta)$. Is this correct?

If so, it feels like I am assuming that my posterior is a normal distribution and merely trying to find likely parameter values for this distribution with regard to the $KL$ divergence.

Vincent Warmerdam

1 Answer


I have a feeling that you treat $p$ as a completely unknown object. I do not think this is the case. This is probably what you missed.

Say we observe $Y = \{y_i\}_{i=1}^n$ (i.i.d.) and we want to infer $p(x|Y)$ where we assume that $p(y|x)$ and $p(x)$ for $x\in\mathbb{R}^d$ are specified by the model. By Bayes' rule,

$$p(x|Y) = \frac{p(x)}{p(Y)}p(Y|x) = \frac{p(x)}{p(Y)}\prod_{i=1}^n p(y_i|x).$$

The first observation is that we know something about the posterior distribution $p(x|Y)$. It is given as above. Typically, we just do not know its normalizer $p(Y)$. If the likelihood $p(y|x)$ is very complicated, then we end up having some complicated distribution $p(x|Y)$.
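The unknown normalizer really does drop out when the divergence is taken in the direction variational inference uses, $KL(q||p(x|Y))$: replacing the normalized posterior by the unnormalized $p(x, Y)$ only shifts the objective by the constant $\log p(Y)$. Below is a minimal grid-based sketch, assuming a hypothetical conjugate model (prior $x \sim N(0, 1)$, likelihood $y_i | x \sim N(x, 1)$) and made-up data $Y$. Here $p(Y)$ is computed on the grid only to check the claim; the exact posterior for this model is $N(0.75, 0.5^2)$.

```python
import math

def log_normal_pdf(z, mean, std):
    return -0.5 * ((z - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))

# Hypothetical data and model: prior x ~ N(0, 1), likelihood y_i | x ~ N(x, 1).
Y = [0.8, 1.2, 1.0]
grid = [-4 + 8 * k / 2000 for k in range(2001)]
dx = grid[1] - grid[0]

def log_joint(x):
    """log p(x, Y) = log p(x) + sum_i log p(y_i | x): computable without p(Y)."""
    return log_normal_pdf(x, 0.0, 1.0) + sum(log_normal_pdf(y, x, 1.0) for y in Y)

log_pt = [log_joint(x) for x in grid]                    # unnormalized log posterior
log_Z = math.log(sum(math.exp(v) for v in log_pt) * dx)  # log p(Y), only for checking
log_p = [v - log_Z for v in log_pt]                      # normalized log posterior

def reverse_kl(m, s, log_target):
    """Riemann-sum approximation of KL(q || target) for q = N(m, s^2)."""
    total = 0.0
    for x, lt in zip(grid, log_target):
        lq = log_normal_pdf(x, m, s)
        total += math.exp(lq) * (lq - lt) * dx
    return total

for m, s in [(0.0, 1.0), (0.75, 0.5)]:
    exact = reverse_kl(m, s, log_p)    # uses the normalized posterior
    unnorm = reverse_kl(m, s, log_pt)  # uses only p(x, Y)
    print(m, s, exact, unnorm, exact - unnorm)
```

Both candidates give (up to grid error) the same difference between the two objectives, roughly $\log p(Y)$, so the ranking of candidate $q$'s, and hence the minimizer, does not depend on $p(Y)$.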

The second thing that makes variational inference possible is that there is a constraint on the form that $q$ can take. Without any constraint, $\arg \min_q KL(p||q)$ would be $p$ itself, which is usually intractable. Typically, $q$ is assumed to live in a chosen subset of the exponential family. For example, this might be the fully factorized (mean-field) family, i.e., $q \in \mathcal{Q} = \{\prod_{i=1}^d q_i(x_i)\}$ with each $q_i$ a one-dimensional distribution; when $\log p(x, Y)$ is quadratic in $x$, each optimal $q_i$ turns out to be a one-dimensional Gaussian. It turns out that if this is your constraint set, then each factor of the optimal $q$ is given by

$$q_i \propto \exp( \mathbb{E}_{\prod_{j\neq i} q_j} \log p(x, Y) ), $$

where $p(x, Y) = p(x) \prod_{i=1}^n p(y_i|x)$. The exact formula does not matter much. The point is that the approximate $q$ can be found by relying on what we know about the true $p$ (the joint $p(x, Y)$, i.e. the posterior up to its normalizer $p(Y)$) together with the assumption on the form that the approximate $q$ should take.
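For a Gaussian target the fixed-point formula above can be evaluated in closed form, which gives a small coordinate-ascent loop. This is a sketch assuming a hypothetical 2-D posterior $p(x|Y) = N(\mu, \Lambda^{-1})$ with made-up $\mu$ and precision matrix $\Lambda$; in a real model these would come from $\log p(x, Y)$. Taking the expectation of $\log p$ over the other factors leaves a quadratic in $x_i$, so each $q_i$ is Gaussian with variance $1/\Lambda_{ii}$ and mean $m_i = \mu_i - \Lambda_{ii}^{-1}\sum_{j\neq i}\Lambda_{ij}(m_j - \mu_j)$.

```python
# Hypothetical 2-D Gaussian target N(mu, inv(Lam)), given by its precision matrix Lam.
mu = [1.0, -2.0]
Lam = [[2.0, 0.9],
       [0.9, 1.0]]

def cavi_gaussian(mu, Lam, iters=100):
    """Mean-field coordinate ascent: q_i ∝ exp(E_{q_j, j != i}[log p(x)]).

    For a Gaussian target each optimal factor is N(m_i, 1 / Lam[i][i])."""
    d = len(mu)
    m = [0.0] * d  # initial factor means
    for _ in range(iters):
        for i in range(d):
            # Only terms of log p involving x_i survive; the expectation over
            # the other factors replaces each x_j by its current mean m_j.
            coupling = sum(Lam[i][j] * (m[j] - mu[j]) for j in range(d) if j != i)
            m[i] = mu[i] - coupling / Lam[i][i]
    variances = [1.0 / Lam[i][i] for i in range(d)]
    return m, variances

means, variances = cavi_gaussian(mu, Lam)
print(means)      # converges to mu
print(variances)  # [0.5, 1.0]
```

The factor means recover $\mu$, but the factor variances $1/\Lambda_{ii}$ are smaller than the true marginal variances (about 0.84 and 1.68 for this $\Lambda$), a known tendency of mean-field approximations under this KL direction.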

Update

The following is to answer the updated part in the question. I just realized that I have been thinking about $KL(q||p(x|Y))$ rather than $KL(p||q)$. I will always use $p$ for the true quantity, and $q$ for an approximate one. In variational inference or variational Bayes, $q$ is given by

$$q = \arg \min_{q \in \mathcal{Q}} KL(q\, ||\, p(x|Y)).$$

With the constraint set $\mathcal{Q}$ as above, the solution is the one given previously. Now if you are thinking about

$$q = \arg \min_{q \in \mathcal{Q}} KL( p(x|Y) \, || \, q),$$

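for $\mathcal{Q}$ defined to be a subset of the exponential family, then the solution for $q$ is the one whose moments (its expected sufficient statistics, e.g. the mean and covariance when $q$ is Gaussian) match those of $p(x|Y)$. This direction of the KL is the one used by expectation propagation (EP), which applies this moment matching locally, one factor of the posterior at a time, rather than to $p(x|Y)$ as a whole.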
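The two directions can behave very differently when $q$ is restricted to one-dimensional Gaussians. The sketch below assumes a hypothetical bimodal posterior $0.5\,N(-2, 0.6^2) + 0.5\,N(2, 0.6^2)$ evaluated on a grid. It minimizes each KL directly over Gaussians (a global toy version, not the EP algorithm itself): $KL(p||q)$ is minimized by the moment-matched Gaussian, while a crude grid search on $KL(q||p)$ settles on a single mode.

```python
import math

def log_normal_pdf(z, mean, std):
    return -0.5 * ((z - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))

def log_target(x):
    """Hypothetical bimodal posterior 0.5 N(-2, 0.6^2) + 0.5 N(2, 0.6^2), via log-sum-exp."""
    a = math.log(0.5) + log_normal_pdf(x, -2.0, 0.6)
    b = math.log(0.5) + log_normal_pdf(x, 2.0, 0.6)
    top = max(a, b)
    return top + math.log(math.exp(a - top) + math.exp(b - top))

grid = [-8 + 16 * k / 1600 for k in range(1601)]
dx = grid[1] - grid[0]
log_p = [log_target(x) for x in grid]

def forward_kl(m, s):
    """KL(p || q): expectation under the target p (the direction EP works with)."""
    return sum(math.exp(lp) * (lp - log_normal_pdf(x, m, s)) * dx
               for x, lp in zip(grid, log_p))

def reverse_kl(m, s):
    """KL(q || p): expectation under the approximation q (the variational direction)."""
    total = 0.0
    for x, lp in zip(grid, log_p):
        lq = log_normal_pdf(x, m, s)
        total += math.exp(lq) * (lq - lp) * dx
    return total

# KL(p || q) over Gaussians: match the target's mean and variance.
mean = sum(math.exp(lp) * x * dx for x, lp in zip(grid, log_p))
var = sum(math.exp(lp) * (x - mean) ** 2 * dx for x, lp in zip(grid, log_p))
s_mm = math.sqrt(var)
print(mean, var)  # about 0 and 0.36 + 4 = 4.36

# KL(q || p) over Gaussians: brute-force search over a coarse (m, s) grid.
candidates = [(0.25 * i, 0.2 * j) for i in range(-12, 13) for j in range(2, 16)]
m_vi, s_vi = min(candidates, key=lambda ms: reverse_kl(*ms))
print(m_vi, s_vi)  # sits on one mode (m near +2 or -2) with a narrow width
```

The moment-matched Gaussian spreads over both modes and the gap between them, whereas the reverse-KL fit ignores one mode entirely; which behaviour is preferable depends on the application.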

Either way, you are right in saying that essentially you try to approximate the true posterior distribution in the KL sense by a distribution $q$ constrained to take some form.

wij