medium primitives

VAE ELBO Loss

Compute the negative ELBO loss used to train a Variational Autoencoder (VAE).

Background

The VAE objective is to maximise the Evidence Lower BOund (ELBO):

ELBO = E_{q(z|x)}[log p(x|z)] - KL(q(z|x) || p(z))

We minimise the negative ELBO, which decomposes into two terms:

  • Reconstruction loss: how well the decoder reconstructs the input.
  • KL divergence: how far the encoder posterior q(z|x) = N(mu, diag(exp(log_var))) is from the prior N(0, I).

Algorithm

reconstruction_loss = mean over batch of sum-squared-error per example
                    = ((reconstructed - original) ** 2).sum(dim=-1).mean()

kl_loss = -0.5 * mean over batch of sum over d_z of (1 + log_var - mu^2 - exp(log_var))

total_loss = reconstruction_loss + kl_loss
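
The three lines above can be sketched in dependency-free Python to make the reductions explicit. This is a minimal illustration, not a reference solution: the function name vae_loss and the nested-list inputs are assumptions made here, and a tensor library would replace the loops with the one-line reductions shown above.

```python
import math


def vae_loss(reconstructed, original, mu, log_var):
    """Negative ELBO for a batch given as nested lists (hypothetical helper).

    reconstructed, original: N rows of d floats.
    mu, log_var:             N rows of d_z floats.
    """
    n = len(original)

    # Reconstruction: sum of squared errors per example, averaged over the batch.
    reconstruction_loss = sum(
        sum((r - o) ** 2 for r, o in zip(rec_row, orig_row))
        for rec_row, orig_row in zip(reconstructed, original)
    ) / n

    # KL(q(z|x) || N(0, I)): closed form summed over the latent dimensions,
    # then averaged over the batch.
    kl_loss = -0.5 * sum(
        sum(1.0 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu_row, lv_row))
        for mu_row, lv_row in zip(mu, log_var)
    ) / n

    return reconstruction_loss + kl_loss
```

As a hand check with N = 2, d = 2, d_z = 1: for reconstructed = [[1, 2], [0, 0]], original = [[0, 0], [0, 0]], mu = [[1], [0]] and log_var = [[0], [0]], the reconstruction term is (5 + 0) / 2 = 2.5 and the KL term is (0.5 + 0) / 2 = 0.25, so the total is 2.75.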

The KL term has a closed form when both distributions are Gaussian. For q(z|x) = N(mu, diag(sigma^2)) and p(z) = N(0, I), summing over the latent dimensions j:

KL = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)
   = -0.5 * sum_j (1 + log_var_j   - mu_j^2 - exp(log_var_j))
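
For readers who want the intermediate step, the per-dimension term can be derived as follows. This is a sketch for a single latent coordinate z_j, using the standard Gaussian log-density and E_q[z_j^2] = mu_j^2 + sigma_j^2; because both distributions factorise over coordinates, the full KL is the sum over j.

```latex
\begin{aligned}
\mathrm{KL}\big(\mathcal{N}(\mu_j, \sigma_j^2)\,\|\,\mathcal{N}(0, 1)\big)
  &= \mathbb{E}_q[\log q(z_j)] - \mathbb{E}_q[\log p(z_j)] \\
  &= \Big(-\tfrac12 \log(2\pi\sigma_j^2) - \tfrac12\Big)
     - \Big(-\tfrac12 \log(2\pi) - \tfrac12\big(\mu_j^2 + \sigma_j^2\big)\Big) \\
  &= -\tfrac12\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)
\end{aligned}
```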

Reference

Kingma & Welling, "Auto-Encoding Variational Bayes" (2013).

Inputs / Output

  • reconstructed: shape (N, d), decoder output.
  • original: shape (N, d), original input.
  • mu: shape (N, d_z), encoder mean.
  • log_var: shape (N, d_z), encoder log variance.

Output: scalar float, the total loss.
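
The shape contract above can be checked before computing the loss. The helper below is a hypothetical sketch (the name check_shapes and the nested-list representation are assumptions, not part of the problem statement); with a tensor library the same checks would compare shape attributes instead.

```python
def check_shapes(reconstructed, original, mu, log_var):
    """Return (N, d, d_z) if the inputs are consistent, else raise ValueError."""

    def shape(rows):
        # Every row must have the same length; an empty batch is rejected too.
        widths = {len(row) for row in rows}
        if len(widths) != 1:
            raise ValueError("rows must be non-empty and of equal length")
        return len(rows), widths.pop()

    n, d = shape(reconstructed)
    if shape(original) != (n, d):
        raise ValueError("original must match reconstructed: (N, d)")

    n_z, d_z = shape(mu)
    if shape(log_var) != (n_z, d_z):
        raise ValueError("log_var must match mu: (N, d_z)")

    if n_z != n:
        raise ValueError("all inputs must share the batch size N")
    return n, d, d_z
```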

