ELBO for Gaussian VI
Why this matters
The Evidence Lower BOund (ELBO) is the training objective of variational inference and the VAE. Because log p(x) = ELBO + KL(q(z|x) || p(z|x)) and log p(x) does not depend on q, maximizing the ELBO over q is equivalent to minimizing KL(q(z|x) || p(z|x)), that is, fitting our variational approximation q to the true posterior p.
The ELBO decomposes as:
ELBO = E_q[log p(x|z)] - KL(q(z) || p(z))
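As a brief sketch of where this decomposition comes from, the identity below splits the log evidence into the ELBO plus the KL to the true posterior, and then rewrites the ELBO using p(x, z) = p(x|z) p(z). Here q(z) is shorthand for the variational posterior q(z|x).

```latex
\begin{aligned}
\log p(x) &= \underbrace{\mathbb{E}_{q}\!\left[\log \frac{p(x, z)}{q(z)}\right]}_{\mathrm{ELBO}}
             + \mathrm{KL}\big(q(z) \,\|\, p(z \mid x)\big) \\
\mathrm{ELBO} &= \mathbb{E}_{q}\!\left[\log \frac{p(x \mid z)\, p(z)}{q(z)}\right]
              = \mathbb{E}_{q}\big[\log p(x \mid z)\big] - \mathrm{KL}\big(q(z) \,\|\, p(z)\big)
\end{aligned}
```

Since the KL term in the first line is non-negative, the ELBO never exceeds log p(x), which is where the "lower bound" in the name comes from.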
For a diagonal Gaussian variational posterior q(z) = N(μ, diag(σ²)) and a standard normal prior p(z) = N(0, I), the KL divergence has a closed-form expression:
KL = 0.5 * Σ_d [ σ_d² + μ_d² - 1 - log(σ_d²) ]
= 0.5 * Σ_d [ exp(2·log_σ_d) + μ_d² - 1 - 2·log_σ_d ]
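One way to gain confidence in the closed form is to compare it against a Monte Carlo estimate of E_q[log q(z) - log p(z)]. The sketch below does this in JAX; the function names, the test values, and the sample count are illustrative choices, not part of the problem.

```python
import jax
import jax.numpy as jnp

def kl_closed_form(mu, log_sigma):
    # KL(N(mu, diag(sigma^2)) || N(0, I)) from the closed-form expression.
    return 0.5 * jnp.sum(jnp.exp(2 * log_sigma) + mu ** 2 - 1 - 2 * log_sigma)

def kl_monte_carlo(mu, log_sigma, key, num_samples=200_000):
    # Estimate E_q[log q(z) - log p(z)] with samples z = mu + sigma * eps.
    sigma = jnp.exp(log_sigma)
    eps = jax.random.normal(key, (num_samples, mu.shape[0]))
    z = mu + sigma * eps
    # The -0.5 * log(2*pi) constants cancel between log q and log p.
    log_q = jnp.sum(-0.5 * eps ** 2 - log_sigma, axis=-1)
    log_p = jnp.sum(-0.5 * z ** 2, axis=-1)
    return jnp.mean(log_q - log_p)

# Illustrative values (assumed for this check only).
mu = jnp.array([0.5, -1.0])
log_sigma = jnp.array([0.2, -0.3])

kl_exact = kl_closed_form(mu, log_sigma)
kl_estimate = kl_monte_carlo(mu, log_sigma, jax.random.PRNGKey(0))
```

The two values should agree up to Monte Carlo noise, which shrinks as num_samples grows.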
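The reconstruction term is approximated by a single-sample Monte Carlo estimate using the reparameterization trick: z = μ + σ·ε, ε ~ N(0, I). For this deterministic problem, ε is fixed at 0, so z = μ. The log-likelihood formula used in this problem corresponds to a unit-variance Gaussian likelihood centered at z, so log p(x|z) = -0.5·Σ_d (x_d - z_d)² up to an additive constant.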
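For contrast with the ε = 0 simplification, here is a minimal sketch of the sampled version of the reparameterization trick in JAX. It assumes a unit-variance Gaussian likelihood centered at z with the additive constant dropped; the helper names and input values are hypothetical.

```python
import jax
import jax.numpy as jnp

def sample_z(mu, log_sigma, key):
    # Reparameterization: z = mu + sigma * eps, with eps ~ N(0, I).
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(log_sigma) * eps

def sampled_log_lik(mu, log_sigma, x, key):
    # Single-sample estimate of E_q[log p(x|z)], assuming p(x|z) = N(z, I)
    # and dropping the additive constant.
    z = sample_z(mu, log_sigma, key)
    return -0.5 * jnp.sum((x - z) ** 2)

# Illustrative values (assumed for this sketch only).
mu = jnp.array([0.5, -1.0])
log_sigma = jnp.array([0.2, -0.3])
x = jnp.array([1.0, 0.0])

# Because z is a deterministic function of (mu, log_sigma, eps),
# gradients flow through the sample to both parameters.
grad_mu, grad_log_sigma = jax.grad(sampled_log_lik, argnums=(0, 1))(
    mu, log_sigma, x, jax.random.PRNGKey(0)
)
```

Writing the sample as a deterministic function of the parameters and an independent noise draw is what lets jax.grad differentiate through it.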
Worked mini-example
import jax.numpy as jnp
mu = jnp.array([0.0, 0.0])
log_sigma = jnp.array([0.0, 0.0]) # σ = 1 → unit Gaussian
x = jnp.array([0.0, 0.0])
# z = mu = [0, 0]
log_lik = -0.5 * jnp.sum((x - mu) ** 2) # = 0.0
kl = 0.5 * jnp.sum(jnp.exp(2 * log_sigma) + mu ** 2 - 1 - 2 * log_sigma)
# kl = 0.5 * (1 + 0 - 1 - 0) * 2 = 0.0
elbo = log_lik - kl # = 0.0
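When q equals the prior and the reconstruction is perfect, the ELBO as computed here is 0. This value omits the Gaussian log-pdf constant of -0.5·log(2π) per dimension, which would otherwise shift it.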
Common pitfalls
- KL sign: ELBO = log_lik − KL. Forgetting the minus sign maximizes the wrong objective.
- Variance vs log-std parameterization: we store log_sigma (unconstrained), so σ² = exp(2·log_sigma), not exp(log_sigma).
- Single-sample MC: using z = mu (ε = 0) is for reproducibility only. Real VAEs sample ε from N(0, I) and average the estimate over multiple samples (in practice often one sample per data point, averaged across a minibatch).
- Missing constants: the Gaussian log-pdf has a constant of -0.5·log(2π) per dimension. We drop it here because it is constant w.r.t. the parameters and so does not affect gradients; this is standard practice in VAE training.
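To illustrate the last pitfall, the sketch below compares the log-likelihood with and without the -0.5·log(2π) constant. The function names and input values are illustrative assumptions; the point is that the values differ by a fixed offset while the gradients with respect to mu are identical.

```python
import jax
import jax.numpy as jnp

def log_lik_no_const(mu, x):
    # Log-likelihood as used in this problem (constant dropped).
    return -0.5 * jnp.sum((x - mu) ** 2)

def log_lik_with_const(mu, x):
    # Full unit-variance Gaussian log-pdf, including -0.5*log(2*pi) per dimension.
    d = x.shape[0]
    return -0.5 * jnp.sum((x - mu) ** 2) - 0.5 * d * jnp.log(2 * jnp.pi)

# Illustrative values (assumed for this check only).
mu = jnp.array([0.5, -1.0])
x = jnp.array([1.0, 0.0])

grad_without = jax.grad(log_lik_no_const)(mu, x)
grad_with = jax.grad(log_lik_with_const)(mu, x)

# The values differ by 0.5 * d * log(2*pi); the gradients do not differ.
value_gap = log_lik_no_const(mu, x) - log_lik_with_const(mu, x)
```

Dropping the constant therefore changes the reported ELBO value but not the optimization trajectory.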
Problem
Implement gaussian_elbo(mu, log_sigma, x) that computes the single-sample
ELBO for a Gaussian variational posterior with a standard normal prior.
mu, log_sigma, and x are 1-D JAX arrays of the same shape (d,).
Return a scalar float.
Formulas (ε = 0, z = mu):
log_likelihood = -0.5 * sum((x - mu)^2)
kl = 0.5 * sum(exp(2 * log_sigma) + mu^2 - 1 - 2 * log_sigma)
elbo = log_likelihood - kl
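The formulas above translate almost line for line into JAX. The following is one possible sketch, not the reference solution; it returns a 0-d JAX array, and whether the grader expects that or a Python float is an assumption left open here.

```python
import jax.numpy as jnp

def gaussian_elbo(mu, log_sigma, x):
    # Reconstruction term with eps = 0, so z = mu (constant dropped).
    log_likelihood = -0.5 * jnp.sum((x - mu) ** 2)
    # Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), with sigma^2 = exp(2 * log_sigma).
    kl = 0.5 * jnp.sum(jnp.exp(2 * log_sigma) + mu ** 2 - 1 - 2 * log_sigma)
    return log_likelihood - kl

# Hand-checkable case: log_likelihood = -0.5 and kl = 0.5.
elbo_value = gaussian_elbo(
    jnp.array([1.0, 0.0]),  # mu
    jnp.array([0.0, 0.0]),  # log_sigma (sigma = 1)
    jnp.array([0.0, 0.0]),  # x
)
```

With mu = [1, 0], unit sigma, and x = [0, 0], the log-likelihood is -0.5 and the KL is 0.5, so the ELBO should come out to -1.0.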