Variational Inference

Latent Variable Model

$p_\theta(x|z)$ ⇒ latent variable model (e.g., a VAE decoder) that transforms a proposed latent $z$ into an observation $x$

We have data points $x$; we want to model the distribution of $x$ with the help of a latent variable $z$.

How do we train latent variable models?

Estimating the log-likelihood

🧙🏽‍♂️
Use the expected log-likelihood:

$$\theta \leftarrow \arg\max_\theta \frac{1}{N} \sum_i \mathbb{E}_{z \sim p(z|x_i)}[\log p_\theta(x_i, z)]$$

Intuition:

$$
\begin{split}
\log p(x_i) &= \log \int_z p(x_i|z)\, p(z)\, dz \\
&= \log \int_z p(x_i|z)\, p(z)\, \frac{q_i(z)}{q_i(z)}\, dz \\
&= \log \mathbb{E}_{z \sim q_i(z)}\left[\frac{p(x_i|z)\, p(z)}{q_i(z)}\right]
\end{split}
$$

Note:

Jensen's Inequality (for a concave function such as $\log$)

$\log \mathbb{E}[y] \ge \mathbb{E}[\log y]$
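A quick numeric sanity check, as a sketch: for a log-normal $y$ we know $\log \mathbb{E}[y] = 1/2$ and $\mathbb{E}[\log y] = 0$, so the gap is visible directly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Positive random variable y = exp(g), g ~ N(0, 1) (log-normal).
y = np.exp(rng.normal(size=100_000))

lhs = np.log(np.mean(y))     # log E[y], exactly 1/2 for this y
rhs = np.mean(np.log(y))     # E[log y], exactly 0 for this y
print(lhs, rhs, lhs >= rhs)  # Jensen: log E[y] >= E[log y]
```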

So

$$
\begin{split}
\log p(x_i) &= \log \mathbb{E}_{z \sim q_i(z)}\left[\frac{p(x_i|z)\, p(z)}{q_i(z)}\right] \\
&\ge \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{p(x_i|z)\, p(z)}{q_i(z)}\right] \\
&= \mathbb{E}_{z \sim q_i(z)}[\log p(x_i|z) + \log p(z)] - \mathbb{E}_{z \sim q_i(z)}[\log q_i(z)] \\
&= \mathbb{E}_{z \sim q_i(z)}[\log p(x_i|z) + \log p(z)] + H(q_i)
\end{split}
$$

Note about entropy:

$$H(p) = -\mathbb{E}_{x \sim p(x)}[\log p(x)] = -\int_x p(x)\log p(x)\, dx$$
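As a small sketch, a Monte Carlo estimate of this definition can be checked against the Gaussian closed form $H = \frac{1}{2}\log(2\pi e \sigma^2)$:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 2.0

# Monte Carlo: H(p) = -E_{x~p}[log p(x)] for p = N(0, sigma^2).
x = rng.normal(scale=sigma, size=100_000)
log_p = -0.5 * np.log(2 * np.pi * sigma**2) - x**2 / (2 * sigma**2)
h_mc = -np.mean(log_p)

# Closed form for a Gaussian: H = 0.5 * log(2*pi*e*sigma^2).
h_exact = 0.5 * np.log(2 * np.pi * np.e * sigma**2)
print(h_mc, h_exact)
```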

Intuitions: entropy measures how random the variable is; equivalently, how low the log probability is in expectation under its own distribution. A broad, spread-out distribution has high entropy; a peaked one has low entropy.

Also note about KL-Divergence:

$$D_{KL}(q \| p) = \mathbb{E}_{x \sim q(x)}\left[\log \frac{q(x)}{p(x)}\right] = \mathbb{E}_{x \sim q(x)}[\log q(x)] - \mathbb{E}_{x \sim q(x)}[\log p(x)] = -\mathbb{E}_{x \sim q(x)}[\log p(x)] - H(q)$$
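A numeric sketch comparing a Monte Carlo estimate of this definition with the standard closed form for two univariate Gaussians:

```python
import numpy as np

rng = np.random.default_rng(0)
m_q, s_q, m_p, s_p = 0.0, 1.0, 1.0, 2.0

def log_normal(x, m, s):
    return -0.5 * np.log(2 * np.pi * s**2) - (x - m)**2 / (2 * s**2)

# Monte Carlo: D_KL(q||p) = E_{x~q}[log q(x) - log p(x)].
x = rng.normal(m_q, s_q, size=100_000)
kl_mc = np.mean(log_normal(x, m_q, s_q) - log_normal(x, m_p, s_p))

# Closed form for two univariate Gaussians.
kl_exact = np.log(s_p / s_q) + (s_q**2 + (m_q - m_p)**2) / (2 * s_p**2) - 0.5
print(kl_mc, kl_exact)
```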

Intuitions: $D_{KL}$ measures how different two distributions are; equivalently, how unlikely samples from $q$ are under $p$, minus the entropy of $q$. It is always $\ge 0$, with equality iff $q = p$.

So the variational approximation gives a lower bound on the log-likelihood:

$$\log p(x_i) \ge \overbrace{\mathbb{E}_{z \sim q_i(z)}[\log p(x_i|z) + \log p(z)] + H(q_i)}^{L_i(p,\, q_i)}$$
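Since the bound is an expectation under $q_i$, it can be estimated by sampling. A minimal sketch, where `q_i`, `log_p_x_given_z`, and `log_p_z` are hypothetical stand-ins for the variational distribution and the model's log densities:

```python
import torch

def elbo_estimate(x_i, q_i, log_p_x_given_z, log_p_z, n_samples=64):
    """Monte Carlo estimate of L_i(p, q_i) = E_q[log p(x_i|z) + log p(z)] + H(q_i).

    q_i: a torch.distributions object with .sample() and .entropy();
    log_p_x_given_z, log_p_z: callables returning log densities (assumed names).
    """
    z = q_i.sample((n_samples,))                  # z_j ~ q_i(z)
    inner = log_p_x_given_z(x_i, z) + log_p_z(z)  # log p(x_i|z_j) + log p(z_j)
    return inner.mean() + q_i.entropy().sum()     # sample average + entropy term
```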

So what makes a good $q_i(z)$? One that closely approximates the true posterior $p(z|x_i)$.

Why?

$$
\begin{split}
D_{KL}(q_i(z) \| p(z|x_i)) &= \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{q_i(z)}{p(z|x_i)}\right] = \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{q_i(z)\, p(x_i)}{p(x_i, z)}\right] \\
&= -\mathbb{E}_{z \sim q_i(z)}[\log p(x_i|z) + \log p(z)] + \mathbb{E}_{z \sim q_i(z)}[\log q_i(z)] + \mathbb{E}_{z \sim q_i(z)}[\log p(x_i)] \\
&= -\mathbb{E}_{z \sim q_i(z)}[\log p(x_i|z) + \log p(z)] - H(q_i) + \log p(x_i) \\
&= -L_i(p, q_i) + \log p(x_i)
\end{split}
$$

(In the last steps, $\mathbb{E}_{z \sim q_i(z)}[\log p(x_i)] = \log p(x_i)$ since $\log p(x_i)$ does not depend on $z$.)

Also:

$$\log p(x_i) = D_{KL}(q_i(z) \| p(z|x_i)) + L_i(p, q_i)$$

In fact, since $\log p(x_i)$ does not depend on $q_i$, minimizing the KL divergence has the effect of tightening the bound!
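This identity can be verified numerically on a toy model with a known posterior. A sketch, assuming $p(z) = N(0,1)$ and $p(x|z) = N(z,1)$, so that $p(x) = N(0,2)$ and $p(z|x) = N(x/2, 1/2)$:

```python
import numpy as np

rng = np.random.default_rng(0)
x_i = 1.3

def log_normal(v, m, s2):
    # log density of N(m, s2) at v
    return -0.5 * np.log(2 * np.pi * s2) - (v - m)**2 / (2 * s2)

# Arbitrary variational Gaussian q_i(z) = N(m_q, s_q^2).
m_q, s_q = 0.2, 0.8
z = rng.normal(m_q, s_q, size=200_000)

# L_i = E_q[log p(x_i|z) + log p(z)] + H(q_i)
L_i = np.mean(log_normal(x_i, z, 1.0) + log_normal(z, 0.0, 1.0)) \
      + 0.5 * np.log(2 * np.pi * np.e * s_q**2)
# D_KL(q_i || p(z|x_i)), estimated by Monte Carlo
kl = np.mean(log_normal(z, m_q, s_q**2) - log_normal(z, x_i / 2, 0.5))

print(L_i + kl, log_normal(x_i, 0.0, 2.0))  # both ≈ log p(x_i)
```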

But how do we improve $q_i$? Keeping a separate $q_i(z)$ for every data point means the number of variational parameters grows with the dataset.

Amortized Variational Inference

Instead, train a single inference network $q_\phi(z|x)$ that outputs the variational distribution for any $x$, amortizing the cost of inference across data points:

$$L_i = \underbrace{\mathbb{E}_{z \sim q_\phi(z|x_i)}[\log p_\theta(x_i|z) + \log p(z)]}_{J(\phi)\, =\, \mathbb{E}_{z \sim q_\phi(z|x_i)}[r(x_i,\, z)]} + H(q_\phi(z|x_i))$$

We can directly calculate the derivative of the entropy term, but the first term $J(\phi)$ seems a bit problematic: the expectation is taken under a distribution that itself depends on $\phi$.

But this form looks a lot like a policy gradient, with $z$ playing the role of the action and $r(x_i, z) = \log p_\theta(x_i|z) + \log p(z)$ the role of the reward!

$$\nabla_\phi J(\phi) \approx \frac{1}{M} \sum_j \nabla_\phi \log q_\phi(z_j|x_i)\, r(x_i, z_j), \qquad z_j \sim q_\phi(z|x_i)$$
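A sketch of this score-function (REINFORCE-style) estimator for a Gaussian $q_\phi$; the quadratic `r` here is a hypothetical stand-in for $\log p_\theta(x_i|z) + \log p(z)$:

```python
import torch
from torch.distributions import Normal

# Learnable parameters of q_phi(z|x_i) (kept explicit for the sketch).
mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)

def r(x_i, z):
    # Hypothetical stand-in for r(x_i, z) = log p_theta(x_i|z) + log p(z).
    return -(z - x_i).pow(2).sum(-1)

x_i = torch.tensor([1.0, -1.0])
M = 256

q = Normal(mu, log_sigma.exp())
z = q.sample((M,))                       # plain sampling: no gradient through z
log_q = q.log_prob(z).sum(-1)            # log q_phi(z_j|x_i)
# grad J ≈ (1/M) sum_j grad log q_phi(z_j|x_i) * r(x_i, z_j)
surrogate = -(log_q * r(x_i, z)).mean()  # minimize negative of the objective
surrogate.backward()
print(mu.grad, log_sigma.grad)
```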

What's wrong with this gradient? Like policy gradients, it has high variance, so in practice it needs many samples and small learning rates.

The reparameterization trick

⚠️
In RL we cannot use this trick because we cannot compute gradients through the (unknown) transition dynamics, but with variational inference we can differentiate through $r$.
$$q_\phi(z|x) = N(\mu_\phi(x), \sigma_\phi(x)) \qquad z = \mu_\phi(x) + \epsilon\, \sigma_\phi(x), \quad \epsilon \sim N(0, 1)$$

This makes $z$ a deterministic function of a random variable $\epsilon$ whose distribution is independent of $\phi$.

So now:

$$
\begin{split}
J(\phi) &= \mathbb{E}_{z \sim q_\phi(z|x_i)}[r(x_i, z)] \\
&= \mathbb{E}_{\epsilon \sim N(0,1)}[r(x_i,\, \mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i))]
\end{split}
$$

And now, to estimate $\nabla_\phi J(\phi)$, we can sample a few $\epsilon_j \sim N(0,1)$ (a single sample often works well) and differentiate directly: $\nabla_\phi J(\phi) \approx \frac{1}{M} \sum_j \nabla_\phi\, r(x_i,\, \mu_\phi(x_i) + \epsilon_j\, \sigma_\phi(x_i))$

This is useful because now we are using the derivative of the $r$ function itself, which typically yields a much lower-variance estimate.
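The same toy setup as the score-function sketch above, now with the reparameterization trick; the gradient flows through `r` into $\mu$ and $\sigma$:

```python
import torch

# Same hypothetical setup: learnable Gaussian q_phi, quadratic stand-in r.
mu = torch.zeros(2, requires_grad=True)
log_sigma = torch.zeros(2, requires_grad=True)

def r(x_i, z):
    # Hypothetical stand-in for r(x_i, z) = log p_theta(x_i|z) + log p(z).
    return -(z - x_i).pow(2).sum(-1)

x_i = torch.tensor([1.0, -1.0])
M = 256

eps = torch.randn(M, 2)              # epsilon ~ N(0, I), independent of phi
z = mu + eps * log_sigma.exp()       # reparameterization: z = mu + eps * sigma
loss = -r(x_i, z).mean()             # gradient flows through r into mu, sigma
loss.backward()
print(mu.grad, log_sigma.grad)
```

(In PyTorch, this is what `Normal(mu, sigma).rsample()` does under the hood.)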

Another Perspective on $L_i$

$$
\begin{split}
L_i &= \mathbb{E}_{z \sim q_\phi(z|x_i)}[\log p_\theta(x_i|z) + \log p(z)] + H(q_\phi(z|x_i)) \\
&= \mathbb{E}_{z \sim q_\phi(z|x_i)}[\log p_\theta(x_i|z)] + \underbrace{\mathbb{E}_{z \sim q_\phi(z|x_i)}[\log p(z)] + H(q_\phi(z|x_i))}_{-D_{KL}(q_\phi(z|x_i)\, \|\, p(z))} \\
&= \mathbb{E}_{z \sim q_\phi(z|x_i)}[\log p_\theta(x_i|z)] - D_{KL}(q_\phi(z|x_i) \| p(z)) \\
&= \mathbb{E}_{\epsilon \sim N(0,1)}[\log p_\theta(x_i|\mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i))] - D_{KL}(q_\phi(z|x_i) \| p(z)) \\
&\approx \log p_\theta(x_i|\mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i)) - D_{KL}(q_\phi(z|x_i) \| p(z)), \quad \epsilon \sim N(0,1)
\end{split}
$$
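The KL term has a closed form when both $q_\phi(z|x_i)$ and $p(z)$ are Gaussians, so only the reconstruction term needs sampling. A sketch with stand-in values for $\mu_\phi(x_i)$ and $\sigma_\phi(x_i)$:

```python
import torch
from torch.distributions import Normal, kl_divergence

# Stand-ins for mu_phi(x_i) and sigma_phi(x_i).
mu, sigma = torch.randn(2), torch.rand(2) + 0.5
q = Normal(mu, sigma)                      # q_phi(z|x_i)
p = Normal(torch.zeros(2), torch.ones(2))  # prior p(z) = N(0, I)

kl = kl_divergence(q, p).sum()             # sum over latent dimensions
# Same thing in closed form: KL = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
kl_manual = 0.5 * (sigma**2 + mu**2 - 1 - (sigma**2).log()).sum()
print(kl, kl_manual)
```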

Reparameterization trick vs. Policy Gradient

Policy gradient: works with both discrete and continuous latent variables, but has high variance and requires multiple samples and small learning rates. Reparameterization trick: low variance and simple to implement (a single sample often suffices), but only works for continuous latent variables and a differentiable $r$.

Example Models

VAE: Variational Autoencoder

$$\max_{\theta, \phi} \frac{1}{N} \sum_i \log p_\theta(x_i|\mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i)) - D_{KL}(q_\phi(z|x_i) \| p(z))$$
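Putting the pieces together, a minimal VAE training step. The architecture, sizes, and the fixed-variance Gaussian decoder are illustrative assumptions, not the canonical choices:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=16, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 2 * z_dim))  # -> mu, log sigma
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, x_dim))      # -> mean of p(x|z)

    def elbo(self, x):
        mu, log_sigma = self.enc(x).chunk(2, dim=-1)
        q = Normal(mu, log_sigma.exp())                         # q_phi(z|x)
        z = q.rsample()                                         # reparameterization trick
        log_px = Normal(self.dec(z), 1.0).log_prob(x).sum(-1)   # fixed-variance decoder
        p = Normal(torch.zeros_like(mu), torch.ones_like(mu))   # prior N(0, I)
        kl = kl_divergence(q, p).sum(-1)                        # analytic KL term
        return (log_px - kl).mean()                             # average L_i over batch

model = VAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(32, 784)          # stand-in batch of data
opt.zero_grad()
loss = -model.elbo(x)            # maximize the ELBO = minimize its negative
loss.backward()
opt.step()
```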

Conditional Models

Generates $y$ given $x, z$

$$L_i = \mathbb{E}_{z \sim q_\phi(z|x_i, y_i)}[\log p_\theta(y_i|x_i, z) + \log p(z|x_i)] + H(q_\phi(z|x_i, y_i))$$
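A conditional sketch mirroring the VAE above: the encoder sees $(x, y)$, the decoder predicts $y$ from $(x, z)$, and the prior $p(z|x)$ is a learned network. All names and sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

class CVAE(nn.Module):
    def __init__(self, x_dim=10, y_dim=5, z_dim=4, hidden=64):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim + y_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 2 * z_dim))    # q_phi(z|x, y)
        self.prior = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(),
                                   nn.Linear(hidden, 2 * z_dim))  # p(z|x)
        self.dec = nn.Sequential(nn.Linear(x_dim + z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, y_dim))        # mean of p_theta(y|x, z)

    def elbo(self, x, y):
        mu_q, ls_q = self.enc(torch.cat([x, y], -1)).chunk(2, -1)
        mu_p, ls_p = self.prior(x).chunk(2, -1)
        q, p = Normal(mu_q, ls_q.exp()), Normal(mu_p, ls_p.exp())
        z = q.rsample()                        # reparameterization trick
        # Gaussian p_theta(y|x, z) with fixed variance (an assumption).
        log_py = Normal(self.dec(torch.cat([x, z], -1)), 1.0).log_prob(y).sum(-1)
        return (log_py - kl_divergence(q, p).sum(-1)).mean()
```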