Difficulty: medium | Category: end_to_end

Train VAE End-to-End

Perform one full VAE training step on a fully-connected encoder/decoder: forward pass, negative-ELBO loss, hand-computed gradients via the chain rule, and an SGD update.

No torch.autograd. No loss.backward(). You derive every partial derivative yourself.

Architecture

A minimal linear VAE (no activation functions, for gradient-checking clarity):

  • Encoder: h = x @ w_enc has shape (N, 2*d_z) and holds mu and log_var concatenated along the last dimension. The first d_z columns of h are mu and the last d_z columns are log_var.
  • Reparameterization: sigma = exp(0.5 * log_var). Sample eps ~ N(0, I) of shape (N, d_z) using the given seed, then set z = mu + sigma * eps.
  • Decoder: x_hat = z @ w_dec.
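The forward pass above can be sketched in NumPy as a stand-in for PyTorch. This is a minimal sketch that assumes eps is drawn up front and passed in explicitly. `vae_forward` is a hypothetical helper name. NumPy's generator will not reproduce the PyTorch eps samples behind the expected outputs, so only shapes and relationships carry over.

```python
import numpy as np

def vae_forward(x, w_enc, w_dec, eps):
    """Linear VAE forward pass (hypothetical helper); eps is passed in explicitly."""
    d_z = w_dec.shape[0]
    h = x @ w_enc                        # (N, 2*d_z)
    mu, log_var = h[:, :d_z], h[:, d_z:]  # split along the last dimension
    sigma = np.exp(0.5 * log_var)        # standard deviation
    z = mu + sigma * eps                 # reparameterization trick
    x_hat = z @ w_dec                    # (N, d)
    return mu, log_var, sigma, z, x_hat

rng = np.random.default_rng(0)
N, d, d_z = 4, 5, 2
x = rng.standard_normal((N, d))
w_enc = 0.1 * rng.standard_normal((d, 2 * d_z))
w_dec = 0.1 * rng.standard_normal((d_z, d))
eps = rng.standard_normal((N, d_z))      # NumPy noise, not PyTorch's
mu, log_var, sigma, z, x_hat = vae_forward(x, w_enc, w_dec, eps)
print(mu.shape, log_var.shape, x_hat.shape)
```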

Loss (negative ELBO)

recon_loss = ((x_hat - x)**2).sum(dim=-1).mean()          # mean over batch, sum over d
kl_loss    = -0.5 * (1 + log_var - mu**2 - exp(log_var)).sum(dim=-1).mean()
loss       = recon_loss + kl_loss
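The loss definition translates almost line for line into NumPy (`axis` in place of `dim`). This is a minimal sketch with hand-picked inputs whose values are easy to check by hand. `neg_elbo` is a hypothetical helper name.

```python
import numpy as np

def neg_elbo(x, x_hat, mu, log_var):
    """Return (loss, recon_loss, kl_loss) for the linear VAE objective."""
    # Squared error summed over features d, then averaged over the batch N.
    recon_loss = ((x_hat - x) ** 2).sum(axis=-1).mean()
    # Closed-form KL summed over latent dims d_z, then averaged over the batch.
    kl_loss = (-0.5 * (1 + log_var - mu**2 - np.exp(log_var))).sum(axis=-1).mean()
    return recon_loss + kl_loss, recon_loss, kl_loss

x = np.array([[1.0, 2.0], [0.0, 1.0]])
x_hat = np.zeros_like(x)
# recon: rows sum to (1 + 4) and (0 + 1), so the batch mean is (5 + 1) / 2 = 3.
# mu = 0, log_var = 0 makes q(z|x) = N(0, I), so the KL term vanishes.
loss_a, recon_a, kl_a = neg_elbo(x, x_hat, np.zeros((2, 3)), np.zeros((2, 3)))
# mu = 1, log_var = 0 adds 0.5 per latent dim: 3 dims give 1.5 per sample.
loss_b, recon_b, kl_b = neg_elbo(x, x_hat, np.ones((2, 3)), np.zeros((2, 3)))
print(recon_a, kl_a, kl_b, loss_b)
```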

The KL term is the closed form of KL(q(z|x) || p(z)) for q(z|x) = N(mu, diag(exp(log_var))) and p(z) = N(0, I). Reference: Kingma & Welling, "Auto-Encoding Variational Bayes" (2013).
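One way to confirm the closed form is to compare it with a Monte Carlo estimate of the definition KL = E_q[log q(z) - log p(z)]. This is a minimal NumPy sketch for a single sample with arbitrary illustrative mu and log_var. `kl_closed_form` and `kl_monte_carlo` are hypothetical helper names.

```python
import numpy as np

def kl_closed_form(mu, log_var):
    # KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims.
    return -0.5 * (1 + log_var - mu**2 - np.exp(log_var)).sum(axis=-1)

def kl_monte_carlo(mu, log_var, n_samples, rng):
    # Estimate E_q[log q(z) - log p(z)] with samples z ~ q.
    sigma = np.exp(0.5 * log_var)
    eps = rng.standard_normal((n_samples, mu.shape[-1]))
    z = mu + sigma * eps
    # log N(z; mu, sigma^2) - log N(z; 0, 1); the log(2*pi) terms cancel.
    log_ratio = -0.5 * log_var - 0.5 * eps**2 + 0.5 * z**2
    return log_ratio.sum(axis=-1).mean()

mu = np.array([0.5, -1.0])
log_var = np.array([-0.3, 0.4])
rng = np.random.default_rng(0)
closed = kl_closed_form(mu, log_var)
mc = kl_monte_carlo(mu, log_var, 200_000, rng)
print(closed, mc)  # the two should agree to about two decimal places
```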

Backward pass (by hand)

d_x_hat     = (2/N) * (x_hat - x)                                   # (N, d)
grad_w_dec  = z.T @ d_x_hat                                          # (d_z, d)
grad_z      = d_x_hat @ w_dec.T                                      # (N, d_z); gradient w.r.t. z, named grad_z to avoid clashing with the latent size d_z

# Through reparameterization z = mu + sigma * eps:
d_mu_recon      = grad_z                         # dz/d_mu = 1
d_log_var_recon = grad_z * 0.5 * sigma * eps     # dz/d_log_var = 0.5 * sigma * eps
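The two local derivatives of the reparameterization, dz/d_mu = 1 and dz/d_log_var = 0.5 * sigma * eps, can be checked with a scalar central finite difference. The sketch below uses arbitrary illustrative values for mu, log_var, and eps. `z_of` is a hypothetical helper name.

```python
import numpy as np

mu, log_var, eps = 0.3, -0.7, 1.25   # arbitrary illustrative values

def z_of(m, lv):
    # Reparameterized sample z = mu + exp(0.5 * log_var) * eps, with eps held fixed.
    return m + np.exp(0.5 * lv) * eps

h = 1e-6
dz_dmu_num = (z_of(mu + h, log_var) - z_of(mu - h, log_var)) / (2 * h)
dz_dlv_num = (z_of(mu, log_var + h) - z_of(mu, log_var - h)) / (2 * h)
dz_dlv_analytic = 0.5 * np.exp(0.5 * log_var) * eps   # 0.5 * sigma * eps
print(dz_dmu_num, dz_dlv_num, dz_dlv_analytic)
```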

# KL contributions:
d_mu_kl     = mu / N                             # from -0.5 * (-2*mu) / N
d_log_var_kl = (exp(log_var) - 1) / (2*N)       # from -0.5 * (1 - exp(log_var)) / N
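The KL gradients can be verified element by element against finite differences of the batch-mean KL term. This is a minimal NumPy sketch on random mu and log_var of shape (N, d_z). `kl_loss` and `numeric_grad` are hypothetical helper names.

```python
import numpy as np

def kl_loss(mu, log_var):
    # Batch-mean KL term, matching kl_loss in the loss definition.
    return (-0.5 * (1 + log_var - mu**2 - np.exp(log_var))).sum(axis=-1).mean()

def numeric_grad(f, a, h=1e-6):
    # Central differences: perturb each entry of `a` in place, then restore it.
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + h
        f_plus = f()
        a[idx] = old - h
        f_minus = f()
        a[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g

rng = np.random.default_rng(1)
N, d_z = 3, 2
mu = rng.standard_normal((N, d_z))
log_var = rng.standard_normal((N, d_z))

# Analytic KL gradients from the formulas above.
d_mu_kl = mu / N
d_log_var_kl = (np.exp(log_var) - 1) / (2 * N)

num_mu = numeric_grad(lambda: kl_loss(mu, log_var), mu)
num_lv = numeric_grad(lambda: kl_loss(mu, log_var), log_var)
print(np.max(np.abs(num_mu - d_mu_kl)), np.max(np.abs(num_lv - d_log_var_kl)))
```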

d_mu      = d_mu_recon + d_mu_kl
d_log_var = d_log_var_recon + d_log_var_kl

d_h        = concat(d_mu, d_log_var, dim=-1)     # (N, 2*d_z)
grad_w_enc = x.T @ d_h                           # (d, 2*d_z)

w_enc -= lr * grad_w_enc
w_dec -= lr * grad_w_dec
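The whole backward pass can be gradient-checked before the SGD update. This NumPy sketch assumes eps is held fixed, so the loss is a deterministic function of the weights. Random small weights stand in for real inputs. `vae_step_grads`, `loss_fn`, and `numeric_grad` are hypothetical helper names, and the gradient with respect to z is called `grad_z` to keep it distinct from the latent size `d_z`.

```python
import numpy as np

def vae_step_grads(x, w_enc, w_dec, eps):
    """Hand-derived gradients of the negative ELBO, following the steps above."""
    N, d_z = x.shape[0], w_dec.shape[0]
    h = x @ w_enc
    mu, log_var = h[:, :d_z], h[:, d_z:]
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * eps
    x_hat = z @ w_dec
    d_x_hat = (2.0 / N) * (x_hat - x)
    grad_w_dec = z.T @ d_x_hat
    grad_z = d_x_hat @ w_dec.T
    d_mu = grad_z + mu / N                                   # recon + KL
    d_log_var = grad_z * 0.5 * sigma * eps + (np.exp(log_var) - 1) / (2 * N)
    d_h = np.concatenate([d_mu, d_log_var], axis=-1)
    grad_w_enc = x.T @ d_h
    return grad_w_enc, grad_w_dec

def loss_fn(x, w_enc, w_dec, eps):
    d_z = w_dec.shape[0]
    h = x @ w_enc
    mu, log_var = h[:, :d_z], h[:, d_z:]
    x_hat = (mu + np.exp(0.5 * log_var) * eps) @ w_dec
    recon = ((x_hat - x) ** 2).sum(axis=-1).mean()
    kl = (-0.5 * (1 + log_var - mu**2 - np.exp(log_var))).sum(axis=-1).mean()
    return recon + kl

def numeric_grad(f, w, h=1e-6):
    # Central differences, perturbing each weight in place and restoring it.
    g = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        old = w[idx]
        w[idx] = old + h
        f_plus = f()
        w[idx] = old - h
        f_minus = f()
        w[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g

rng = np.random.default_rng(0)
N, d, d_z, lr = 4, 3, 2, 0.01
x = rng.standard_normal((N, d))
w_enc = 0.5 * rng.standard_normal((d, 2 * d_z))
w_dec = 0.5 * rng.standard_normal((d_z, d))
eps = rng.standard_normal((N, d_z))      # fixed noise for the check
w_enc0, w_dec0 = w_enc.copy(), w_dec.copy()

g_enc, g_dec = vae_step_grads(x, w_enc, w_dec, eps)

def current_loss():
    return loss_fn(x, w_enc, w_dec, eps)

num_enc = numeric_grad(current_loss, w_enc)
num_dec = numeric_grad(current_loss, w_dec)
print(np.max(np.abs(g_enc - num_enc)), np.max(np.abs(g_dec - num_dec)))

# SGD update; both gradients were computed from the pre-update weights.
w_enc -= lr * g_enc
w_dec -= lr * g_dec
```

Computing both gradients before touching either weight matrix matters: updating w_dec first and then backpropagating through the new w_dec would give a different (incorrect) encoder gradient.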

Inputs

  • x: shape (N, d), the input batch.
  • w_enc: shape (d, 2*d_z), the encoder weights; columns [:d_z] map to mu, [d_z:] to log_var.
  • w_dec: shape (d_z, d), the decoder weights.
  • lr: float, the SGD learning rate.
  • seed: int, used for reproducible reparameterization noise.

Output

Returns a 1-D array of shape (d * 2*d_z + d_z * d,): the concatenation of (final_w_enc.flatten(), final_w_dec.flatten()).
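The output layout can be illustrated with NumPy, whose `flatten` (default order='C') reads elements row by row, as PyTorch's `flatten` does. The sketch uses placeholder weights filled with consecutive numbers so the ordering is visible. The final weights here are illustrative and are not the result of a training step.

```python
import numpy as np

d, d_z = 5, 2
# Placeholder "final" weights filled with consecutive values to expose the order.
final_w_enc = np.arange(0.0, 20.0).reshape(d, 2 * d_z)   # d * 2*d_z = 20 entries
final_w_dec = np.arange(20.0, 30.0).reshape(d_z, d)      # d_z * d = 10 entries
out = np.concatenate([final_w_enc.flatten(), final_w_dec.flatten()])
print(out.shape)  # (30,)
```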

PRNG note

PyTorch and JAX use different pseudo-random number generators. Given the same seed, they produce different eps samples, so the updated weights differ between frameworks. The expected outputs here were generated with PyTorch only (framework: :pytorch).
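As a NumPy-only analogy (it involves neither PyTorch nor JAX), two different generator algorithms seeded with the same integer produce different normal samples. Re-seeding the same generator reproduces its samples exactly.

```python
import numpy as np

seed = 0
legacy = np.random.RandomState(seed).standard_normal(3)   # legacy Mersenne Twister
pcg = np.random.default_rng(seed).standard_normal(3)      # default PCG64 generator
again = np.random.default_rng(seed).standard_normal(3)    # same generator, same seed
print(legacy)
print(pcg)
```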
