class: center, middle
# Bayesian Methods for Machine Learning .small-vspace[ ] ### Lecture 5 - Variational inference
.bold[Simon Leglaive]
.tiny[CentraleSupélec] --- class: middle, center ## Approximate inference --- In Bayesian methods, it's all about posterior computation: $$p(\mathbf{z}|\mathbf{x} ; \theta) = \frac{p(\mathbf{x}|\mathbf{z} ; \theta) p(\mathbf{z} ; \theta) }{ p(\mathbf{x} ; \theta)} = \frac{p(\mathbf{x}|\mathbf{z} ; \theta) p(\mathbf{z} ; \theta)}{\int p(\mathbf{x}|\mathbf{z} ; \theta)p(\mathbf{z} ; \theta) d\mathbf{z}}$$ - Easy to compute when the prior is conjugate to the likelihood. - Hard for the many models where conjugacy does not hold. -- count: false For example: $$ p(\mathbf{x}|\mathbf{z} ; \theta) = \mathcal{N}\left(\mathbf{x}; \boldsymbol{\mu}\_\theta(\mathbf{z}), \boldsymbol{\Sigma}\_\theta(\mathbf{z})\right),$$ where $\boldsymbol{\mu}\_\theta$ and $\boldsymbol{\Sigma}\_\theta$ are **neural networks**. The marginal likelihood $\displaystyle p(\mathbf{x} ; \theta) = \int p(\mathbf{x}|\mathbf{z} ; \theta)p(\mathbf{z} ; \theta) d\mathbf{z} = \int\mathcal{N}\left(\mathbf{x}; \boldsymbol{\mu}\_\theta(\mathbf{z}), \boldsymbol{\Sigma}\_\theta(\mathbf{z})\right)p(\mathbf{z}; \theta) d\mathbf{z}$ is **intractable** due to the non-linearities. .footnote[To simplify notation, we denote the parameters of the likelihood and prior by $\theta$ even though they are generally different.] --- class: middle We need **approximate inference** techniques when exact posterior computation is infeasible: - **Variational inference** (focus of today) - Markov Chain Monte Carlo --- class: middle, center ## Variational inference --- class: middle The main idea of **variational inference** is to cast inference as an **optimization problem**. We want to find a **variational distribution** $q(\mathbf{z}) \in \mathcal{F}$ which approximates the true intractable posterior $p(\mathbf{z}|\mathbf{x} ; \theta)$. 
We need to define: - a **measure of fit** between $q(\mathbf{z})$ and $p(\mathbf{z}|\mathbf{x} ; \theta)$, to be minimized, - a **variational family** $\mathcal{F}$, which corresponds to the set of acceptable solutions for the variational distribution. --- class: middle ### The KL divergence $$D\_{\text{KL}}(q \parallel p) = \mathbb{E}\_{q}[ \ln q - \ln p].$$ The KL divergence has the following properties: - $D\_{\text{KL}}(q \parallel p) \ge 0$ - $D\_{\text{KL}}(q \parallel p) = 0$ if and only if $q = p$ - $D\_{\text{KL}}(q \parallel p) \neq D\_{\text{KL}}(p \parallel q)$ Why do we choose $D\_{\text{KL}}(q \parallel p)$ and not $D\_{\text{KL}}(p \parallel q)$? --- exclude: true class: middle - Why do we choose $D\_{\text{KL}}(q \parallel p)$ and not $D\_{\text{KL}}(p \parallel q)$? .small-vspace[ ] $D\_{\text{KL}}(q \parallel p)$ involves an expectation w.r.t $q$, while for $D\_{\text{KL}}(p \parallel q)$ the expectation is taken w.r.t $p$, which is intractable (when it is the posterior). - How does this choice influence the approximation? --- exclude: true $$ D\_{\text{KL}}(q \parallel p) = \mathbb{E}\_{q}[ \ln \big(q / p \big)] $$ This is the **reverse KL**, which is large when $p$ is close to zero and $q$ is not. This form penalizes distributions $q$ that put probability mass where $p$ is small. However it is ok if $q$ is close to zero while $p$ is not. As a consequence, it may underestimate the support of $p$, i.e. $q$ may concentrate on a single mode of $p$. .center.width-50[![](images/KL_1.png)] --- exclude: true $$D\_{\text{KL}}(p \parallel q) = \mathbb{E}\_{p}[ \ln \big(p / q \big)]$$ The **forward KL** is large when $q$ is close to zero and $p$ is not. This form penalizes distributions $q$ that "would not sufficiently cover" $p$. However it is ok if $q$ has probability mass where $p$ is close to zero. As a consequence, it may overestimate the support of $p$, i.e. $q$ may have probability mass on regions where $p$ does not. 
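
The asymmetry between the reverse and forward KL can be checked on a small discrete case. The sketch below is only an illustration: the bimodal target `p` and the two candidates `q_one_mode` and `q_spread` are made-up numbers, not taken from the lecture.

```python
import math


def kl_divergence(a, b):
    """D_KL(a || b) = sum_k a_k (ln a_k - ln b_k) for discrete distributions."""
    return sum(ak * (math.log(ak) - math.log(bk)) for ak, bk in zip(a, b) if ak > 0)


# Hypothetical bimodal target p over 5 states and two candidate approximations.
p = [0.48, 0.01, 0.02, 0.01, 0.48]
q_one_mode = [0.96, 0.01, 0.01, 0.01, 0.01]  # concentrates on one mode of p
q_spread = [0.2, 0.2, 0.2, 0.2, 0.2]         # covers the whole support of p

for name, q in [("one mode", q_one_mode), ("spread", q_spread)]:
    print(name, "| reverse KL(q||p):", kl_divergence(q, p),
          "| forward KL(p||q):", kl_divergence(p, q))
```

With these numbers, the reverse KL is smaller for the single-mode candidate, while the forward KL is smaller for the spread-out one, in line with the behaviour described above.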
.center.width-30[![](images/KL_2.png)] --- exclude: true class: middle, black-slide .center[
] .credit[Image credit: https://twitter.com/ari_seff/status/1303741288911638530] --- class: middle ### The ELBO We take the (reverse) **KL divergence** as a **measure of fit** between the variational distribution and the intractable posterior: $$ \begin{aligned} D\_{\text{KL}}(q(\mathbf{z}) \parallel p(\mathbf{z}|\mathbf{x} ; \theta)) & = \mathbb{E}\_{q(\mathbf{z})}[ \ln q(\mathbf{z}) - \ln p(\mathbf{z}|\mathbf{x} ; \theta)] \\\\ & = \mathbb{E}\_{q(\mathbf{z})}[ \ln q(\mathbf{z}) - \ln p(\mathbf{x}, \mathbf{z} ; \theta) + \ln p(\mathbf{x} ; \theta)] \\\\ & = \ln p(\mathbf{x} ; \theta) - \mathcal{L}(q(\mathbf{z}), \theta) \end{aligned} $$ where $$\mathcal{L}(q(\mathbf{z}), \theta) = \mathbb{E}\_{q(\mathbf{z})}[\ln p(\mathbf{x}, \mathbf{z}; \theta) - \ln q(\mathbf{z})]$$ is the **evidence lower bound** (ELBO) that we already encountered in the EM algorithm. It is also called the (negative) **variational free energy**. --- class: middle The ELBO can be further decomposed as: $$\mathcal{L}(q(\mathbf{z}), \theta) = E(q(\mathbf{z}), \theta) + H(q(\mathbf{z})),$$ where $E(q(\mathbf{z}), \theta) = \mathbb{E}\_{q(\mathbf{z})}[\ln p(\mathbf{x}, \mathbf{z}; \theta)]$ and $ H(q(\mathbf{z})) = - \mathbb{E}\_{q(\mathbf{z})}[\ln q(\mathbf{z})]$ is the differential entropy of $q(\mathbf{z})$ which does not depend on the model parameters $\theta$. --- Variational inference consists in **solving the following optimization problem**: $$ \begin{aligned} & \underset{q \in \mathcal{F}}{\min}\hspace{.25cm} D\_{\text{KL}}(q(\mathbf{z}) \parallel p(\mathbf{z}|\mathbf{x} ; \theta)) \qquad \Leftrightarrow \qquad \underset{q \in \mathcal{F}}{\max}\hspace{.25cm} \mathcal{L}(q(\mathbf{z}), \theta). \end{aligned} $$ -- count: false If the variational family $\mathcal{F}$ is not constrained (i.e. 
it is the set of all pdfs over $\mathbf{z}$), we have: $$ \begin{aligned} q^\star(\mathbf{z}) &= \underset{q \in \mathcal{F}}{\arg\min}\hspace{.25cm} D\_{\text{KL}}(q(\mathbf{z}) \parallel p(\mathbf{z}|\mathbf{x} ; \theta)) \\\\ &= \underset{q \in \mathcal{F}}{\arg\max}\hspace{.25cm} \mathcal{L}(q(\mathbf{z}), \theta) \\\\ &= p(\mathbf{z}|\mathbf{x} ; \theta), \end{aligned} $$ which corresponds to the E-step of the EM algorithm for exact inference... -- count: false ... but our starting hypothesis was **"the true posterior is analytically intractable"**, so we need to induce some constraints on the variational distribution $q(\mathbf{z})$, through the definition of the variational family $\mathcal{F}$. --- class: middle Our goal is to restrict the family sufficiently such that it comprises only tractable distributions. But at the same time we want the family to be sufficiently rich and flexible such that it can provide a good approximation to the true posterior distribution. It is important to emphasize that the restriction is imposed purely to achieve tractability, and that subject to this requirement we should use a family of approximating distributions as rich as possible. --- class: center, middle ## Mean-field variational inference --- The **mean field approximation** defines the variational family $\mathcal{F}$ as the set of pdfs that can be factorized as follows: $$ q(\mathbf{z}) = \prod\_{i=1}^L q\_i(z\_i), $$ where $\mathbf{z} = \\{z\_i\\}\_{i=1}^L$. The mean field approximation assumes that the individual scalar latent variables are independent .italic[a posteriori], that is for all $(i,j)$ with $i \neq j$, $$q(z\_i, z\_j) = q\_i(z\_i) q\_j(z\_j),$$ even though this may not hold for the true posterior: $$p(z\_i, z\_j | \mathbf{x}; \theta) \neq p(z\_i | \mathbf{x}; \theta) p(z\_j | \mathbf{x}; \theta).$$ It should be emphasized that we are making no further assumptions about the distribution. 
In particular, we place no restriction on the functional forms of the individual factors $q\_i(z\_i)$. --- class: middle Among all distributions $q(\mathbf{z})$ that factorize as in the mean-field (MF) approximation, we now seek the one that maximizes the ELBO $\mathcal{L}(q(\mathbf{z}), \theta)$. Let's plug the MF factorization into the definition of the ELBO: $$ \begin{aligned} \mathcal{L}(q(\mathbf{z}), \theta) =& \mathbb{E}\_{q(\mathbf{z})}[\ln p(\mathbf{x}, \mathbf{z}; \theta) - \ln q(\mathbf{z})] \\\\ =& \int \prod\limits\_{i=1}^{L} q\_i(z\_i) \left[ \ln p(\mathbf{x},\mathbf{z} ; \theta) - \ln\left(\prod\limits\_{i=1}^{L} q\_i(z\_i)\right) \right] d\mathbf{z} \\\\[.5cm] =& \,... \text{ \footnotesize (see derivation details in the supporting document)} \\\\[.2cm] =& -D\_{\text{KL}}(q\_j(z\_j) \parallel \tilde{p}(\mathbf{x}, z\_j; \theta)) - \sum\_{i \neq j} \mathbb{E}\_{q\_i(z\_i)}\left[\ln q\_i(z\_i)\right], \end{aligned} $$ where $\ln \tilde{p}(\mathbf{x}, z\_j ; \theta) = \mathbb{E}\_{\prod\_{i \neq j} q\_i(z\_i)}\left[ \ln p(\mathbf{x},\mathbf{z} ; \theta ) \right] + cst.$ .footnote[The constant ensures that the distribution integrates to one.] --- class: middle We adopt a **coordinate ascent** approach, where we **alternately maximize** $\mathcal{L}(q(\mathbf{z}), \theta)$ with respect to each individual factor $q\_j(z\_j)$ while keeping the other ones $\\{q\_i(z\_i)\\}\_{i \neq j}$ fixed. 
-- count: false From the previous expression of the ELBO, we have: $$ q\_j^\star(z\_j) = \underset{q\_j(z\_j)}{\arg\max}\hspace{.25cm} \mathcal{L}(q(\mathbf{z}), \theta) = \underset{q\_j(z\_j)}{\arg\min}\hspace{.25cm} D\_{\text{KL}}(q\_j(z\_j) \parallel \tilde{p}(\mathbf{x}, z\_j; \theta)).$$ -- count: false The optimal distribution which minimizes the KL divergence is therefore given by: $$ \ln q\_j^\star(z\_j) = \ln \tilde{p}(\mathbf{x}, z\_j; \theta) = \mathbb{E}\_{\prod\_{i \neq j} q\_i(z\_i)}\left[ \ln p(\mathbf{x},\mathbf{z} ; \theta ) \right] + cst. $$ --- class: middle The constant can be determined by normalizing $q\_j^\star(z\_j)$ such that it integrates to one: $$ q\_j^\star(z\_j) = \frac{\exp\left(\mathbb{E}\_{\prod\_{i \neq j} q\_i(z\_i)}\left[ \ln p(\mathbf{x},\mathbf{z} ; \theta ) \right]\right)}{\int \exp\left(\mathbb{E}\_{\prod\_{i \neq j} q\_i(z\_i)}\left[ \ln p(\mathbf{x},\mathbf{z} ; \theta ) \right]\right) dz\_j}. $$ However, usually we simply **expand** $\mathbb{E}\_{\prod\_{i \neq j} q\_i(z\_i)}\left[ \ln p(\mathbf{x},\mathbf{z} ; \theta ) \right]$ and **identify** the form of a common distribution (e.g. Gaussian, inverse-gamma, etc.). --- class: middle The optimal distribution $q\_j^\star(z\_j)$ depends on the other factors $q\_i(z\_i)$, $i \neq j$, involved in the MF approximation. The **solutions for different indices are therefore coupled**. A **consistent global solution is obtained iteratively**, by first initializing all the factors and then cycling over each individual one to compute the update. --- class: middle, center ## Example: mean-field approximation of the bivariate Gaussian --- class: middle We consider the problem of approximating a Gaussian distribution using a factorized Gaussian. It will provide useful insights into the types of inaccuracy introduced by the mean-field approximation. 
Consider a bivariate Gaussian random vector $\mathbf{z} = [z\_1, z\_2]^\top$ such that $$p(\mathbf{z}; \theta) = \mathcal{N}\left(\mathbf{z} ; \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1} \right),$$ where the **parameters $\theta$ are assumed to be known** and correspond to the **mean vector and precision matrix** (inverse of the covariance matrix) which are structured as $$ \boldsymbol{\mu} = \begin{pmatrix} \mu\_1 \\\\ \mu\_2 \end{pmatrix}, \qquad\qquad \boldsymbol{\Lambda} = \begin{pmatrix} \Lambda\_{11} & \Lambda\_{12} \\\\ \Lambda\_{21} & \Lambda\_{22} \end{pmatrix},$$ and $\Lambda\_{21} = \Lambda\_{12}$ due to the symmetry of the precision matrix. --- class: middle Suppose now that, under the mean field approximation, we want to find a factorized variational distribution $$q(\mathbf{z}) = q\_1(z\_1)q\_2(z\_2)$$ which approximates $$p(\mathbf{z}; \theta)$$ using the reverse KL divergence as a measure of discrepancy. --- class: middle We have seen that: 1. minimizing $D\_{\text{KL}}(q(\mathbf{z}) \parallel p(\mathbf{z}; \theta))$ w.r.t $q(\mathbf{z})$ is equivalent to maximizing the ELBO, 2. under the mean field approximation, the optimal factor should satisfy $$ \ln q\_j^\star(z\_j) = \mathbb{E}\_{q\_i(z\_i)}\left[ \ln p(\mathbf{z}; \theta) \right] + cst, \qquad j\in \\{1,2\\}, i \neq j.$$ We now have to expand this expression, **ignoring all the terms that do not depend on $z\_j$**, because they can be absorbed into the normalization constant. Let us focus on $q\_1^\star(z\_1)$, as $q\_2^\star(z\_2)$ can simply be obtained by symmetry. .footnote[The complete-data likelihood is by definition the joint pdf of the observed and latent variables, which we denoted by $p(\mathbf{x}, \mathbf{z}; \theta)$ before. In the current example, we only have latent variables and the complete-data likelihood simply corresponds to $p(\mathbf{z} ; \theta)$.] 
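---
class: middle

Before doing the algebra, the optimal-factor expression can be evaluated numerically. The following is a minimal sketch, assuming arbitrary illustrative values for $\boldsymbol{\mu}$, $\boldsymbol{\Lambda}$ and for the current factor $q\_2(z\_2)$ (taken here to be a Gaussian with mean `m2` and variance `v2`). All names are hypothetical; the expectation and the normalization are approximated by sums over regular grids.

```python
import math

# Hypothetical parameters of the target p(z) = N(mu, Lambda^{-1}).
mu1, mu2 = 0.0, 1.0
lam11, lam12, lam22 = 2.0, 0.9, 1.5  # precision entries, Lambda_21 = Lambda_12


def log_p(z1, z2):
    """ln p(z; theta) up to an additive constant."""
    d1, d2 = z1 - mu1, z2 - mu2
    return -0.5 * (lam11 * d1 ** 2 + 2.0 * lam12 * d1 * d2 + lam22 * d2 ** 2)


def grid(lo, hi, n):
    """n regularly spaced points between lo and hi (inclusive)."""
    step = (hi - lo) / (n - 1)
    return [lo + k * step for k in range(n)]


# Current factor q_2(z_2): an arbitrary Gaussian, discretized on a grid.
m2, v2 = 2.0, 0.7
z2_grid = grid(m2 - 8.0 * math.sqrt(v2), m2 + 8.0 * math.sqrt(v2), 401)
w2 = [math.exp(-0.5 * (z2 - m2) ** 2 / v2) for z2 in z2_grid]
total = sum(w2)
w2 = [w / total for w in w2]  # weights summing to one

# ln q_1*(z_1) = E_{q_2}[ln p(z; theta)] + cst, evaluated on a grid over z_1.
z1_grid = grid(-6.0, 6.0, 601)
log_q1 = [sum(w * log_p(z1, z2) for w, z2 in zip(w2, z2_grid)) for z1 in z1_grid]

# Exponentiate and normalize (the shift avoids numerical underflow).
shift = max(log_q1)
q1 = [math.exp(v - shift) for v in log_q1]
norm = sum(q1)
q1 = [v / norm for v in q1]

mean1 = sum(q * z for q, z in zip(q1, z1_grid))
var1 = sum(q * (z - mean1) ** 2 for q, z in zip(q1, z1_grid))
print("mean of q_1*:", mean1, "variance of q_1*:", var1)
```

The printed mean and variance can be compared with the closed-form values obtained by identification in the derivation that follows.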
--- class: middle .small-vspace[ ] $$ \begin{aligned} \ln q\_1^\star(z\_1) &= \mathbb{E}\_{q\_2(z\_2)}\left[ \ln p(\mathbf{z}; \theta) \right] + cst \\\\ &= \mathbb{E}\_{q\_2(z\_2)}\left[ \ln \mathcal{N}\left( \begin{pmatrix} z\_1 \\\\ z\_2 \end{pmatrix}; \begin{pmatrix} \mu\_1 \\\\ \mu\_2 \end{pmatrix}, \begin{pmatrix} \Lambda\_{11} & \Lambda\_{12} \\\\ \Lambda\_{21} & \Lambda\_{22} \end{pmatrix}^{-1} \right) \right] + cst \\\\ &= \mathbb{E}\_{q\_2(z\_2)}\left[ -\frac{1}{2} \begin{pmatrix} z\_1 - \mu\_1 \\\\ z\_2 - \mu\_2 \end{pmatrix}^\top \begin{pmatrix} \Lambda\_{11} & \Lambda\_{12} \\\\ \Lambda\_{21} & \Lambda\_{22} \end{pmatrix} \begin{pmatrix} z\_1 - \mu\_1 \\\\ z\_2 - \mu\_2 \end{pmatrix} \right] + cst \\\\ &= \mathbb{E}\_{q\_2(z\_2)}\left[-\frac{1}{2} \Big( (z\_1 - \mu\_1)^2 \Lambda\_{11} + 2(z\_1 - \mu\_1)\Lambda\_{12}(z\_2 - \mu\_2) \Big) \right] + cst \\\\ &= \mathbb{E}\_{q\_2(z\_2)}\left[ -\frac{1}{2} z\_1^2 \Lambda\_{11} + z\_1\Big(\mu\_1 \Lambda\_{11} - \Lambda\_{12}(z\_2 - \mu\_2) \Big) \right] + cst \\\\ &= -\frac{1}{2} z\_1^2 \Lambda\_{11} + z\_1\Big(\mu\_1 \Lambda\_{11} - \Lambda\_{12}(\mathbb{E}\_{q\_2(z\_2)}\left[z\_2 \right] - \mu\_2) \Big) + cst \end{aligned} $$ --- Let's recall the result: $$\ln q\_1^\star(z\_1) = -\frac{1}{2} z\_1^2 \Lambda\_{11} + z\_1\Big(\mu\_1 \Lambda\_{11} - \Lambda\_{12}(\mathbb{E}\_{q\_2(z\_2)}\left[z\_2 \right] - \mu\_2) \Big) + cst, $$ This is a **quadratic function** of $z\_1$, so the optimal distribution is a **Gaussian distribution** $q\_1^\star(z\_1) = \mathcal{N}(m\_1, \gamma\_1^{-1})$. 
The mean and precision can be determined by **identification**: $$ \begin{aligned} \ln q\_1^\star(z\_1) &= \ln \mathcal{N}(m\_1, \gamma\_1^{-1}) \\\\ &= -\frac{1}{2} (z\_1 - m\_1)^2 \gamma\_1 + cst \\\\ &= -\frac{1}{2} z\_1^2 \gamma\_1 + z\_1 m\_1 \gamma\_1 + cst, \end{aligned} $$ and we identify: $$\gamma\_1 = \Lambda\_{11}, \qquad\qquad m\_1 = \mu\_1 - \Lambda\_{11}^{-1}\Lambda\_{12}(\mathbb{E}\_{q\_2(z\_2)}\left[z\_2 \right] - \mu\_2).$$ --- class: middle It is important to emphasize that **we did not assume that $q\_1(z\_1)$ is Gaussian**. We obtained this result by optimizing the KL divergence under the mean field approximation, which is the only assumption we made. Note also that **we did not compute the normalizing constant** for $q\_1(z\_1)$ explicitly; we simply recognized the form of a known distribution (Gaussian), which implicitly gives us the normalizing constant. --- class: middle By symmetry we also have: $$ q\_2^\star(z\_2) = \mathcal{N}(m\_2, \gamma\_2^{-1}),$$ with $$\gamma\_2 = \Lambda\_{22}, \qquad\qquad m\_2 = \mu\_2 - \Lambda\_{22}^{-1}\Lambda\_{21}(\mathbb{E}\_{q\_1(z\_1)}\left[z\_1 \right] - \mu\_1).$$ --- class: middle To sum up, from an **initialization** of $q\_1^\star(z\_1)$ and $q\_2^\star(z\_2)$ we **iterate**: $$ \begin{aligned} q\_1^\star(z\_1) &= \mathcal{N}\Big( \underbrace{\mu\_1 - \Lambda\_{11}^{-1}\Lambda\_{12}\big( {\color{brown}\mathbb{E}\_{q\_2^\star(z\_2)}\left[z\_2 \right]} - \mu\_2\big)}\_{m\_1},\, \Lambda\_{11}^{-1}\Big), \\\\[.3cm] q\_2^\star(z\_2) &= \mathcal{N}\Big(\underbrace{\mu\_2 - \Lambda\_{22}^{-1}\Lambda\_{21}\big( {\color{brown}\mathbb{E}\_{q\_1^\star(z\_1)}\left[z\_1 \right]} - \mu\_1\big)}\_{m\_2},\, \Lambda\_{22}^{-1}\Big), \end{aligned} $$ with $${\color{brown}\mathbb{E}\_{q\_2^\star(z\_2)}\left[z\_2 \right]} = m\_2, \qquad\qquad {\color{brown}\mathbb{E}\_{q\_1^\star(z\_1)}\left[z\_1 \right]} = m\_1. 
$$ It is clear that these solutions are coupled, as $q\_1^\star(z\_1)$ depends on an expectation computed with respect to $q\_2^\star(z\_2)$ and vice versa. .footnote[More precisely, we iterate the updates of the variational parameters, i.e., the parameters of the variational distributions.] --- After convergence, we can compare the resulting approximation $q(\mathbf{z}) = q\_1(z\_1)q\_2(z\_2)$ with the original distribution $p(\mathbf{z}; \theta)$. .center.width-30[![](images/MF_approx.png)] .caption[Green: original distribution, red: mean field approximation] - it captures the mean correctly, - the variance is underestimated (due to the choice of the reverse KL), - the elongated shape is missing (by construction of the mean field approximation). .credit[Image credit: Christopher M. Bishop, [Pattern Recognition and Machine Learning](https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf), Springer, 2006.] --- class: middle, center ## Exercise ### 1D Gaussian with latent mean and variance --- class: middle ### Problem Consider a dataset $\mathbf{x} = \\{x\_1, x\_2, ..., x\_N\\}$ of i.i.d. realizations of a univariate Gaussian random variable $x \sim \mathcal{N}(\mu, \tau^{-1})$. The mean $\mu$ and precision $\tau$ are modeled as latent random variables. We are interested in inferring their posterior distribution, given the observations $\mathbf{x}$. --- class: middle ### Generative model - Gaussian likelihood: $$ \begin{aligned} p(\mathbf{x} | \mu, \tau\) &= \prod\_{i=1}^N p(x\_i | \mu, \tau) = \prod\_{i=1}^N \mathcal{N}(x\_i ; \mu, \tau^{-1}) = \left(\frac{\tau}{2 \pi}\right)^{N/2} \exp\left( -\frac{\tau}{2} \sum\_{i=1}^N (x\_i - \mu)^2 \right). 
\end{aligned} $$ - Gaussian prior for the mean (conjugate): $$ \begin{aligned} p(\mu | \tau\) &= \mathcal{N}\left(\mu ; \mu\_0, (\lambda\_0 \tau)^{-1} \right) = \left(\frac{\lambda\_0 \tau}{2 \pi}\right)^{1/2} \exp\left( -\frac{\lambda\_0 \tau}{2} (\mu - \mu\_0)^2 \right). \end{aligned} $$ - Gamma prior for the precision (conjugate): $$ \begin{aligned} p(\tau\) &= \mathcal{G}\left(\tau ; a\_0, b\_0 \right) = \frac{b\_0^{a\_0}}{\Gamma(a\_0)} \tau^{(a\_0 - 1)} \exp(-b\_0 \tau), \end{aligned} $$ where $\Gamma(\cdot)$ is the Gamma function. --- class: middle ### True posterior (homework) For this simple problem where the priors are conjugate for the likelihood, the posterior distribution can be found exactly, and it also takes the form of a Gaussian-gamma distribution. --- class: middle $$ p(\mu, \tau | \mathbf{x}) = p(\mu | \mathbf{x}, \tau) p(\tau | \mathbf{x}), $$ with $$ p(\mu | \mathbf{x}, \tau) = \mathcal{N}\left(\mu ; \mu\_\star, \lambda\_\star^{-1} \right), \qquad\qquad p(\tau | \mathbf{x} ) = \mathcal{G}\left(\tau ; \alpha, \beta \right),$$ - $\displaystyle \mu\_\star = \frac{N \tau}{N \tau + \lambda\_0 \tau} \bar{x} + \frac{\lambda\_0 \tau}{N \tau + \lambda\_0 \tau} \mu\_0, \qquad \bar{x} = \frac{1}{N} \sum\_{i=1}^N x\_i $ - $\lambda\_\star = N \tau + \lambda\_0 \tau $ - $\displaystyle \alpha = a\_0 + \frac{N}{2}$ - $\displaystyle \beta = b\_0 + \frac{1}{2} \sum\_{i=1}^N (x\_i - \bar{x})^2 + \frac{\lambda\_0 N}{2(\lambda\_0 + N)} (\bar{x} - \mu\_0)^2$ .footnote[The proof is quite involved, especially for $p(\tau | \mathbf{x})$, but it's a very good exercise. You can check [this document](https://people.eecs.berkeley.edu/~jordan/courses/260-spring10/lectures/lecture5.pdf) for some hints.] --- class: middle For practice purposes, we will consider an approximate posterior distribution using the mean field approximation: $$ q(\mu, \tau) = q\_\mu(\mu) q\_\tau(\tau).$$ Note that as shown in the previous slide, the **true posterior does not factorize like this**. 
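---
class: middle

For later comparison with the mean-field solution, the closed-form parameters of the exact posterior listed earlier can be computed directly from the data. This is a minimal sketch: the function name `exact_posterior_params` and the returned `kappa` (the factor such that $\lambda\_\star = \kappa \tau$, i.e. $\kappa = \lambda\_0 + N$) are hypothetical names introduced only for illustration. Note that $\tau$ cancels in $\mu\_\star$.

```python
def exact_posterior_params(x, mu0, lambda0, a0, b0):
    """Parameters of p(mu | x, tau) = N(mu_star, (kappa * tau)^{-1})
    and p(tau | x) = Gamma(alpha, beta) for the conjugate model."""
    n = len(x)
    x_bar = sum(x) / n
    scatter = sum((xi - x_bar) ** 2 for xi in x)

    mu_star = (n * x_bar + lambda0 * mu0) / (n + lambda0)  # tau cancels out
    kappa = n + lambda0                                    # lambda_star = kappa * tau
    alpha = a0 + n / 2
    beta = (b0 + 0.5 * scatter
            + lambda0 * n * (x_bar - mu0) ** 2 / (2 * (lambda0 + n)))
    return mu_star, kappa, alpha, beta


# Toy data and hyperparameters, chosen arbitrarily for the example.
params = exact_posterior_params([1.0, 2.0, 3.0], mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0)
print(params)
```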
--- class: middle .center.bold[Exercise] --- Using the variational inference recipe, show that the optimal factors $q\_\mu^\star(\mu)$ and $q\_\tau^\star(\tau)$ are given by: $$ q\_\mu^\star(\mu) = \mathcal{N}(\mu; \mu\_N, \lambda\_N^{-1}), \qquad \qquad q\_\tau^\star(\tau) = \mathcal{G}(\tau; a\_N, b\_N), $$ where $\lambda\_N = \mathbb{E}\_{q\_\tau(\tau)}[\tau] (\lambda\_0 + N)$, $\mu\_N = \displaystyle \frac{\lambda\_0 \mu\_0 + N \bar{x}}{\lambda\_0 + N}$, $a\_N = a\_0 + (N+1)/2$, $\begin{aligned} \displaystyle b\_N &= b\_0 + \frac{1}{2} \mathbb{E}\_{q\_\mu(\mu)}\left[ \sum\_{i=1}^N (x\_i - \mu)^2 + \lambda\_0(\mu - \mu\_0)^2 \right] \\\\ &= b\_0 + \frac{1}{2} \left( \sum\_{i=1}^N x\_i^2 + \mathbb{E}\_{q\_\mu(\mu)}[ \mu^2 ] (\lambda\_0 + N) - 2\mathbb{E}\_{q\_\mu(\mu)}[\mu] (\lambda\_0\mu\_0 + N \bar{x}) + \lambda\_0 \mu\_0^2 \right).\end{aligned}$ --- class: middle Using the properties of the Gaussian and Gamma distributions, the required expectations are given by $$ \mathbb{E}\_{q\_\tau(\tau)}[\tau] = a\_N / b\_N, $$ $$ \mathbb{E}\_{q\_\mu(\mu)}[\mu] = \mu\_N, $$ $$ \mathbb{E}\_{q\_\mu(\mu)}[\mu^2] = \mu\_N^2 + \lambda\_N^{-1}. $$ --- class: center, middle .width-60[![](images/blackboard.jpg)] --- class: center, middle .width-20[![](images/jupyter.png)] --- class: middle, center ## Variational EM algorithm --- class: middle, center So far, we assumed that all the deterministic parameters of our model (likelihood and priors) are known. **What if they are not?** --- class: middle ### Generative model with latent variables (reminder) Let $\mathbf{x} \in \mathcal{X}$ and $\mathbf{z} \in \mathcal{Z}$ denote the **observed and latent** random variables, respectively. 
Developing a probabilistic model consists in defining the joint distribution of the observed and latent variables, also called **complete-data likelihood**: $$ p(\mathbf{x}, \mathbf{z}; \theta) = p(\mathbf{x} | \mathbf{z}; \theta) p(\mathbf{z}; \theta), $$ where $\theta$ is a set of **unknown deterministic parameters**. .footnote[To simplify notation, we use $\theta$ to denote both the parameters of the prior and likelihood, but these two distributions usually depend on disjoint sets of parameters.] --- class: middle ### Maximum marginal likelihood estimation of the model parameters (reminder) $$ \hat{\theta} = \underset{\theta}{\arg\max}\hspace{.2cm} p(\mathbf{x}; {\theta}) = \underset{{\theta}}{\arg\max} \int p(\mathbf{x} | \mathbf{z} ; {\theta})p(\mathbf{z} ; {\theta}) d\mathbf{z}.$$ Quite often, directly solving the optimization problem associated with this ML estimation procedure is difficult, if not impossible when the marginal likelihood cannot be computed analytically. We have seen in a previous lecture that in this case, we can leverage the fact that we have latent variables to derive an **expectation-maximization** (EM) algorithm to estimate the model parameters. --- class: middle ### EM algorithm (reminder) The EM algorithm is an iterative algorithm which alternates between optimizing the ELBO $$ \mathcal{L}(q(\mathbf{z}), \theta) = \mathbb{E}\_{q(\mathbf{z})} [\ln p(\mathbf{x}, \mathbf{z}; \theta) - \ln q(\mathbf{z} )] $$ with respect to $q(\mathbf{z}) \in \mathcal{F}$ in the E-Step and with respect to $\theta$ in the M-step. 
We first **initialize** $\theta^\star$, then we iterate: - **E-Step**: $ q^\star(\mathbf{z}) = \underset{q(\mathbf{z}) \in \mathcal{F}}{\arg\max}\, \mathcal{L}(q(\mathbf{z}), \theta^\star) $ - **M-Step**: $ \theta^\star = \underset{\theta}{\arg\max}\, \mathcal{L}(q^\star(\mathbf{z}), \theta) $ --- class: middle When the **family $\mathcal{F}$ is unconstrained**, the solution of the E-Step is given by the posterior distribution: $$ q^\star(\mathbf{z}) = p(\mathbf{z} | \mathbf{x}; \theta^\star). $$ But **what if this posterior is intractable**? -- count: false We have to constrain the family $\mathcal{F}$, typically with the **mean field approximation**. --- class: middle ### Variational EM algorithm with the mean field approximation Let the family $\mathcal{F}$ denote the set of probability density functions that can be factorized as: $$ q(\mathbf{z}) = \prod\_{i=1}^L q\_i(z\_i), \qquad\qquad \mathbf{z} = \\{z\_i\\}\_{i=1}^L.$$ Given an **initialization** $\theta^\star$, the variational EM (VEM) algorithm consists in iterating: - **E-Step**: $ q^\star(\mathbf{z}) = \underset{q(\mathbf{z}) \in \mathcal{F}}{\arg\max}\, \mathcal{L}(q(\mathbf{z}), \theta^\star) $ - **M-Step**: $ \theta^\star = \underset{\theta}{\arg\max}\, \mathcal{L}(q^\star(\mathbf{z}), \theta) $ We have seen that the solution of the E-Step consists in cyclically computing for $j=1,...,L$: $$ \ln q\_j^\star(z\_j) = \mathbb{E}\_{\prod\_{i \neq j} q\_i(z\_i)}\left[ \ln p(\mathbf{x},\mathbf{z} ; \theta^\star ) \right] + cst. $$
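---
class: middle

The cyclic E-Step update can be illustrated on a toy model with two discrete latent variables, where every expectation is a finite sum. This is a sketch under stated assumptions, not the lecture's implementation: the table `joint` holds made-up values of $p(\mathbf{x}, z\_1, z\_2; \theta^\star)$ at the observed $\mathbf{x}$, and the function names are hypothetical. Each update applies $\ln q\_j^\star(z\_j) = \mathbb{E}\_{q\_i(z\_i)}[\ln p(\mathbf{x}, \mathbf{z}; \theta^\star)] + cst$ and then normalizes.

```python
import math


def normalize_from_log(log_w):
    """Exponentiate and normalize log-weights (shifted for numerical stability)."""
    shift = max(log_w)
    w = [math.exp(v - shift) for v in log_w]
    total = sum(w)
    return [v / total for v in w]


def elbo(log_joint, q1, q2):
    """ELBO = E_q[ln p(x, z) - ln q(z)] for q(z) = q1(z1) q2(z2)."""
    value = 0.0
    for i, qi in enumerate(q1):
        for j, qj in enumerate(q2):
            if qi > 0.0 and qj > 0.0:
                value += qi * qj * (log_joint[i][j] - math.log(qi) - math.log(qj))
    return value


def mean_field_e_step(log_joint, n_sweeps=50):
    """Coordinate ascent over the two factors, starting from uniform factors."""
    k1, k2 = len(log_joint), len(log_joint[0])
    q1 = [1.0 / k1] * k1
    q2 = [1.0 / k2] * k2
    history = [elbo(log_joint, q1, q2)]
    for _ in range(n_sweeps):
        # ln q1*(z1) = E_{q2}[ln p(x, z1, z2)] + cst
        q1 = normalize_from_log(
            [sum(q2[j] * log_joint[i][j] for j in range(k2)) for i in range(k1)])
        # ln q2*(z2) = E_{q1}[ln p(x, z1, z2)] + cst
        q2 = normalize_from_log(
            [sum(q1[i] * log_joint[i][j] for i in range(k1)) for j in range(k2)])
        history.append(elbo(log_joint, q1, q2))
    return q1, q2, history


# Made-up values of p(x, z1, z2; theta*) at the observed x (rows: z1, columns: z2).
joint = [[0.10, 0.02, 0.01],
         [0.03, 0.20, 0.04]]
log_joint = [[math.log(v) for v in row] for row in joint]

q1, q2, history = mean_field_e_step(log_joint)
log_evidence = math.log(sum(sum(row) for row in joint))  # ln p(x; theta*)
print("initial ELBO:", history[0], "final ELBO:", history[-1],
      "log-evidence:", log_evidence)
```

The ELBO should never decrease from one sweep to the next and should stay below the log-evidence $\ln p(\mathbf{x}; \theta^\star)$; the remaining gap is the KL divergence between the factorized $q$ and the exact posterior.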