Latent Variable Model
We have data points $x$, and we want to model the distribution of $x$ with the help of a latent variable $z$.
How do we train latent variable models?
- Model $p_\theta(x)$
- Data: $\mathcal{D} = \{x_1, x_2, x_3, \dots, x_N\}$
- MLE fit: $\theta \leftarrow \arg\max_\theta \frac{1}{N} \sum_i \log p_\theta(x_i)$
- $p(x) = \int p(x \mid z)\, p(z)\, dz$
- $\theta \leftarrow \arg\max_\theta \frac{1}{N} \sum_i \log \underbrace{\left( \int p_\theta(x_i \mid z)\, p(z)\, dz \right)}_{\text{completely intractable}}$
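To see why estimating this marginal naively is hopeless, here is a minimal sketch on an assumed toy model (the model $p(z) = \mathcal{N}(0,1)$, $p(x \mid z) = \mathcal{N}(z, 0.5^2)$ is an illustration, not from the notes):

```python
# Naive Monte Carlo estimate of p(x) = E_{z~p(z)}[p(x|z)] on a toy model.
import numpy as np

rng = np.random.default_rng(0)

def gaussian_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

x = 1.5                                        # a single data point
z = rng.standard_normal(100_000)               # z ~ p(z) = N(0, 1)
p_x = gaussian_pdf(x, mean=z, std=0.5).mean()  # (1/M) sum_j p(x | z_j)
print(f"Monte Carlo estimate of p(x): {p_x:.4f}")
# The exact marginal here is N(0, 1.25), so p(x) ~= 0.145. In one dimension
# this works, but in high dimensions almost no prior samples land where
# p(x|z) is large, which is why the integral is intractable in practice.
```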
Estimating the log-likelihood
Intuition:
- Guess the most likely $z$ given $x_i$, and pretend it is the right one
- But there are many possible values of $z$, so use the distribution $p(z \mid x_i)$
- But how do we calculate $p(z \mid x_i)$?
    - Approximate it with $q_i(z) = \mathcal{N}(\mu_i, \sigma_i)$
    - Note that each data point $x_i$ gets its own $q_i$
- With any $q_i(z)$, we can construct a lower bound on $\log p(x_i)$
- By maximizing this bound (assuming the bound is tight), we push up on $\log p(x_i)$
Note: Jensen's inequality (for a concave function $f$): $f(\mathbb{E}[y]) \ge \mathbb{E}[f(y)]$, and $\log$ is concave.
So:
$$\log p(x_i) = \log \int p(x_i \mid z)\, p(z)\, dz = \log \mathbb{E}_{z \sim q_i(z)}\!\left[ \frac{p(x_i \mid z)\, p(z)}{q_i(z)} \right] \ge \mathbb{E}_{z \sim q_i(z)}\!\left[ \log \frac{p(x_i \mid z)\, p(z)}{q_i(z)} \right] = \mathbb{E}_{z \sim q_i(z)}\left[ \log p(x_i \mid z) + \log p(z) \right] + \mathcal{H}(q_i)$$
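A quick numeric sanity check of Jensen's inequality for $\log$ (the random variable here is an assumed toy):

```python
# Check log E[y] >= E[log y] for an arbitrary positive random variable y.
import numpy as np

rng = np.random.default_rng(0)
y = rng.uniform(0.1, 2.0, size=100_000)

lhs = np.log(y.mean())        # log E[y]
rhs = np.log(y).mean()        # E[log y]
print(lhs, rhs, lhs >= rhs)   # lhs is strictly larger here -> True
```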
Note about entropy: $\mathcal{H}(q) = -\mathbb{E}_{z \sim q(z)}[\log q(z)]$
Intuitions:
- How random is the random variable?
- How large is the log probability in expectation under itself?
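A minimal sketch checking a Monte Carlo entropy estimate against the Gaussian closed form (the parameters $\mu, \sigma$ are assumptions for illustration):

```python
# Estimate H(q) = -E_{z~q}[log q(z)] by sampling; compare with the Gaussian
# closed form H = 0.5 * log(2*pi*e*sigma^2).
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.7
z = rng.normal(mu, sigma, size=200_000)
log_q = -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))

print(-log_q.mean())                              # Monte Carlo entropy
print(0.5 * np.log(2 * np.pi * np.e * sigma**2))  # closed form, ~1.062
```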
Also note about KL-divergence: $D_{\mathrm{KL}}(q \,\|\, p) = \mathbb{E}_{z \sim q(z)}\!\left[ \log \frac{q(z)}{p(z)} \right] = -\mathbb{E}_{z \sim q(z)}[\log p(z)] - \mathcal{H}(q)$
Intuitions:
- How different are two distributions?
- How small is the expected log probability of one distribution under another, minus entropy?
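Similarly, a sketch comparing a Monte Carlo KL estimate with the closed form for two assumed Gaussians:

```python
# D_KL(q || p) = E_{z~q}[log q(z) - log p(z)], versus the Gaussian closed form.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
q_mu, q_sigma = 1.0, 0.5     # q(z)
p_mu, p_sigma = 0.0, 1.0     # p(z)

z = rng.normal(q_mu, q_sigma, size=200_000)
kl_mc = (norm.logpdf(z, q_mu, q_sigma) - norm.logpdf(z, p_mu, p_sigma)).mean()

# Closed form: log(s_p/s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2) - 0.5
kl_exact = (np.log(p_sigma / q_sigma)
            + (q_sigma**2 + (q_mu - p_mu)**2) / (2 * p_sigma**2) - 0.5)
print(kl_mc, kl_exact)   # both ~0.818
```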
So the variational approximation:
$$\log p(x_i) \ge \mathbb{E}_{z \sim q_i(z)}\left[ \log p(x_i \mid z) + \log p(z) \right] + \mathcal{H}(q_i) =: \mathcal{L}_i(p, q_i)$$
So what makes a good qi(z)?
- $q_i(z)$ should approximate $p(z \mid x_i)$
- Compare in terms of the KL-divergence $D_{\mathrm{KL}}(q_i(z) \,\|\, p(z \mid x_i))$
Why?
$$D_{\mathrm{KL}}(q_i(z) \,\|\, p(z \mid x_i)) = \mathbb{E}_{z \sim q_i(z)}\!\left[ \log \frac{q_i(z)}{p(z \mid x_i)} \right] = \mathbb{E}_{z \sim q_i(z)}\!\left[ \log \frac{q_i(z)\, p(x_i)}{p(x_i, z)} \right] = -\mathbb{E}_{z \sim q_i(z)}\left[ \log p(x_i \mid z) + \log p(z) \right] - \mathcal{H}(q_i) + \log p(x_i)$$
Also:
$$D_{\mathrm{KL}}(q_i(z) \,\|\, p(z \mid x_i)) = -\mathcal{L}_i(p, q_i) + \log p(x_i) \quad\Rightarrow\quad \log p(x_i) = D_{\mathrm{KL}}(q_i(z) \,\|\, p(z \mid x_i)) + \mathcal{L}_i(p, q_i)$$
In fact, minimizing the KL-divergence has the effect of tightening the bound! (Since $D_{\mathrm{KL}} \ge 0$, we always have $\log p(x_i) \ge \mathcal{L}_i(p, q_i)$, with equality when $q_i(z) = p(z \mid x_i)$.)
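A sketch verifying this identity numerically on an assumed conjugate toy model, where the posterior and the evidence are both available in closed form (none of the specific numbers come from the notes):

```python
# Verify log p(x_i) = D_KL(q_i || p(z|x_i)) + L_i(p, q_i) on a toy model:
# p(z) = N(0,1), p(x|z) = N(z, 0.5^2), so p(z|x) and p(x) are Gaussian too.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x, s = 1.5, 0.5
post_mu, post_var = x / (1 + s**2), s**2 / (1 + s**2)   # exact posterior
log_px = norm.logpdf(x, 0.0, np.sqrt(1 + s**2))          # exact evidence

q_mu, q_sigma = 0.3, 0.6                                 # an arbitrary q_i(z)
z = rng.normal(q_mu, q_sigma, size=500_000)
elbo = ((norm.logpdf(x, z, s) + norm.logpdf(z, 0, 1)).mean()
        + 0.5 * np.log(2 * np.pi * np.e * q_sigma**2))   # E_q[...] + H(q_i)
kl = (np.log(np.sqrt(post_var) / q_sigma)
      + (q_sigma**2 + (q_mu - post_mu)**2) / (2 * post_var) - 0.5)

print(log_px, elbo + kl)   # the two agree up to Monte Carlo error
```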
But how to improve qi?
- We could let $q_i = \mathcal{N}(\mu_i, \sigma_i)$ and use the gradients $\nabla_{\mu_i} \mathcal{L}_i(p, q_i)$ and $\nabla_{\sigma_i} \mathcal{L}_i(p, q_i)$
- But that is too many parameters! We would have $|\theta| + (|\mu_i| + |\sigma_i|) \times N$ parameters
- Then can we use $q_i(z) = q_\phi(z \mid x_i) \approx p(z \mid x_i)$?
    - Amortized variational inference! (See the encoder sketch below.)
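A minimal sketch of what amortization looks like in code (PyTorch; the architecture and sizes are assumptions): a single network outputs $\mu_\phi(x)$ and $\sigma_\phi(x)$ for any $x$, so the parameter count no longer grows with $N$:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """q_phi(z|x) = N(mu_phi(x), sigma_phi(x)): one network for all x_i."""
    def __init__(self, x_dim=784, z_dim=32, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)         # mu_phi(x)
        self.log_sigma = nn.Linear(hidden, z_dim)  # log sigma_phi(x), kept positive via exp

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.log_sigma(h).exp()

# |phi| is fixed; it replaces the (|mu_i| + |sigma_i|) x N per-datapoint parameters.
```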
Amortized Variational Inference
Writing the bound with the amortized posterior,
$$\mathcal{L}_i = \underbrace{\mathbb{E}_{z \sim q_\phi(z \mid x_i)}\left[ \log p_\theta(x_i \mid z) + \log p(z) \right]}_{J(\phi)} + \mathcal{H}(q_\phi(z \mid x_i)),$$
we can directly calculate the derivative of the entropy term, but the first term $J(\phi) = \mathbb{E}_{z \sim q_\phi(z \mid x_i)}[r(x_i, z)]$, where $r(x_i, z) = \log p_\theta(x_i \mid z) + \log p(z)$, seems a bit problematic: $\phi$ appears in the sampling distribution.
But this gradient looks a lot like a policy gradient!
$$\nabla_\phi J(\phi) \approx \frac{1}{M} \sum_j \nabla_\phi \log q_\phi(z_j \mid x_i)\, r(x_i, z_j), \quad z_j \sim q_\phi(z \mid x_i)$$
What’s wrong with this gradient?
- Perfectly viable approach, but not the best one
- Tends to suffer from high variance (see the score-function sketch below)
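A minimal sketch of this policy-gradient-style (score-function) estimator on an assumed toy $r$, just to show the mechanics:

```python
import torch

phi = torch.tensor([0.5, 0.0], requires_grad=True)  # [mu, log_sigma] of q_phi

def r(z):                        # toy stand-in for log p(x_i|z) + log p(z)
    return -(z - 2.0) ** 2

M = 1000
mu, sigma = phi[0], phi[1].exp()
z = (mu + sigma * torch.randn(M)).detach()   # z_j ~ q_phi; no gradient through sampling
log_q = torch.distributions.Normal(mu, sigma).log_prob(z)
surrogate = (log_q * r(z)).mean()            # its gradient is the score-function estimator
surrogate.backward()
print(phi.grad)                              # high-variance estimate of grad_phi J(phi)
```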
The reparameterization trick
$$z = \mu_\phi(x) + \epsilon\, \sigma_\phi(x), \quad \epsilon \sim \mathcal{N}(0, 1)$$
This makes $z$ a deterministic function of a random variable $\epsilon$ that is independent of $\phi$.
So now:
$$J(\phi) = \mathbb{E}_{z \sim q_\phi(z \mid x_i)}[r(x_i, z)] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, 1)}\left[ r\big(x_i, \mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i)\big) \right]$$
And now, to estimate $\nabla_\phi J(\phi)$:
- Sample $\epsilon_1, \dots, \epsilon_M$ from $\mathcal{N}(0, 1)$ (even a single sample works well)
- Each $\epsilon_j$ is treated as a constant when computing the gradient
- $\nabla_\phi J(\phi) \approx \frac{1}{M} \sum_j \nabla_\phi\, r\big(x_i, \mu_\phi(x_i) + \epsilon_j\, \sigma_\phi(x_i)\big)$
This is useful because the estimator now uses the derivative of $r$ itself: gradients flow through the model via backpropagation, which is what makes it low-variance. A sketch of this estimator follows.
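The same assumed toy problem as above, now with the reparameterized estimator; note that the gradient passes through $r$:

```python
import torch

phi = torch.tensor([0.5, 0.0], requires_grad=True)  # [mu, log_sigma] of q_phi

def r(z):                        # toy stand-in for log p(x_i|z) + log p(z)
    return -(z - 2.0) ** 2

M = 1000
eps = torch.randn(M)             # eps_j ~ N(0,1), constant w.r.t. phi
z = phi[0] + eps * phi[1].exp()  # z = mu_phi + eps * sigma_phi
J = r(z).mean()                  # (1/M) sum_j r(x_i, mu + eps_j * sigma)
J.backward()
print(phi.grad)                  # low-variance estimate of grad_phi J(phi)
```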
Another Perspective of $\mathcal{L}_i$
$$\mathcal{L}_i = \mathbb{E}_{z \sim q_\phi(z \mid x_i)}\left[ \log p_\theta(x_i \mid z) \right] - D_{\mathrm{KL}}\big(q_\phi(z \mid x_i) \,\|\, p(z)\big)$$
The first term is the expected reconstruction log-likelihood; the second keeps the approximate posterior close to the prior.
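When $q_\phi(z \mid x)$ is Gaussian and $p(z) = \mathcal{N}(0, I)$, the KL term has a well-known closed form; a minimal sketch with assumed tensor shapes:

```python
# D_KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
import torch

mu = torch.randn(8, 32)        # batch of mu_phi(x)
log_var = torch.randn(8, 32)   # batch of log sigma_phi(x)^2
kl = 0.5 * (log_var.exp() + mu**2 - 1 - log_var).sum(dim=1)
print(kl.shape)                # one KL value per data point
```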
Reparameterization trick vs. Policy Gradient
- Policy gradient
    - Can handle both discrete and continuous latent variables
    - High variance, requires multiple samples & small learning rates
- Reparameterization trick
    - Only continuous latent variables
    - Simple to implement & low variance
Example Models
VAE: Variational Autoencoder
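A minimal VAE sketch in PyTorch tying the pieces together, i.e. amortized encoder, reparameterized sampling, and the two ELBO terms (the sizes and the Bernoulli decoder are assumptions, e.g. for binarized-MNIST-like data; an illustrative sketch, not a reference implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=32, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_var = nn.Linear(hidden, z_dim)
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, x_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, log_var = self.mu(h), self.log_var(h)
        eps = torch.randn_like(mu)                # reparameterization trick
        z = mu + eps * (0.5 * log_var).exp()
        logits = self.dec(z)                      # parameters of p_theta(x|z)
        # Negative ELBO = -E_q[log p_theta(x|z)] + D_KL(q_phi(z|x) || N(0, I))
        rec = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
        kl = 0.5 * (log_var.exp() + mu**2 - 1 - log_var).sum()
        return (rec + kl) / x.shape[0]

model = VAE()
x = torch.rand(16, 784).round()   # assumed stand-in for a binary data batch
loss = model(x)
loss.backward()                   # trains theta (decoder) and phi (encoder) jointly
```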
Conditional Models
Generate $y$ given $x$ and $z$, modeling $p(y \mid x) = \int p(y \mid x, z)\, p(z)\, dz$.
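A minimal sketch of the conditional case with assumed shapes: the decoder consumes both the condition $x$ and the latent $z$:

```python
import torch
import torch.nn as nn

# Decoder for p(y|x, z): input is the concatenation of x (dim 10) and z (dim 32).
decoder = nn.Sequential(nn.Linear(10 + 32, 256), nn.ReLU(), nn.Linear(256, 5))
x = torch.randn(16, 10)   # condition
z = torch.randn(16, 32)   # z ~ p(z); at train time, from the approximate posterior
y_params = decoder(torch.cat([x, z], dim=1))   # parameters of p(y|x, z)
```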