
VAE Reparameterization Trick

Implement the reparameterization trick used in Variational Autoencoders (VAEs).

The Problem

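In a VAE, the encoder outputs the parameters of a distribution q(z|x) = N(mu, sigma^2) over the latent variable z: a Gaussian with diagonal covariance, so mu and sigma have the same size as z. The training loss depends on a sample z drawn from this distribution, so to train the encoder with gradient descent, gradients must pass back through that sampling step.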

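Drawing z ~ N(mu, sigma^2) directly is a stochastic operation that is not a differentiable function of mu and sigma, so backpropagation cannot pass through it. The reparameterization trick sidesteps this by writing the sample as a deterministic function of the parameters and a noise variable drawn from a fixed distribution: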

z = mu + sigma * eps,   eps ~ N(0, I)

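Gradients now flow to mu and sigma through an ordinary differentiable expression: elementwise, dz/dmu = 1 and dz/dsigma = eps. The noise eps is an input drawn from a fixed distribution and is treated as a constant, so no gradient has to pass through the sampling step.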
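A small numerical check of this claim, written in PyTorch (the framework used for the expected outputs). The encoder outputs are hypothetical hand-picked tensors, and sigma is made a leaf tensor directly so its gradient can be read off; in the actual problem sigma is derived from log_var. After backpropagating a sum, the gradient with respect to mu should be all ones and the gradient with respect to sigma should equal the sampled eps.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

# Hypothetical encoder outputs for one latent vector with d = 3.
mu = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
sigma = torch.tensor([1.0, 0.5, 0.1], requires_grad=True)

# eps is drawn outside the graph: it does not depend on mu or sigma.
eps = torch.randn(3)

# Reparameterized sample: deterministic and differentiable in (mu, sigma).
z = mu + sigma * eps

# Summing makes dLoss/dz = 1 for every element of z.
z.sum().backward()

print(mu.grad)                       # tensor([1., 1., 1.])
print(torch.equal(sigma.grad, eps))  # True
```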

Algorithm

sigma   = exp(0.5 * log_var)      # convert log-variance to std dev
eps     = sample from N(0, I), same shape as mu, using the given seed
z       = mu + sigma * eps
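A minimal PyTorch sketch of these three lines, assuming a dedicated torch.Generator seeded with seed. The function name reparameterize is illustrative. The problem does not say exactly how the reference seeds its sampler, so this sketch implements the same computation but is not guaranteed to reproduce the expected values bit for bit.

```python
import torch


def reparameterize(mu: torch.Tensor, log_var: torch.Tensor, seed: int) -> torch.Tensor:
    """Return z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, I)."""
    # log_var = log(sigma^2), so sigma = exp(0.5 * log_var) is always positive.
    sigma = torch.exp(0.5 * log_var)

    # Assumption: a dedicated generator seeded with `seed`, leaving the global
    # RNG state untouched. The reference solution may seed differently.
    gen = torch.Generator(device=mu.device).manual_seed(seed)

    # eps matches mu in shape, dtype and device, and carries no gradient.
    eps = torch.randn(mu.shape, generator=gen, dtype=mu.dtype, device=mu.device)

    # Elementwise: gradients reach mu and log_var, never the sampling call.
    return mu + sigma * eps


# Usage: a batch of N = 2 latent vectors with d = 4.
mu = torch.zeros(2, 4, requires_grad=True)
log_var = torch.zeros(2, 4, requires_grad=True)
z = reparameterize(mu, log_var, seed=42)
print(z.shape)  # torch.Size([2, 4])
```

Using a local generator instead of torch.manual_seed keeps the function free of side effects on global RNG state. If the reference instead calls torch.manual_seed(seed) followed by torch.randn_like(mu), that variant may be needed to match its exact numbers.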

Why It Works

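  • mu and log_var are outputs of the encoder network, and sigma = exp(0.5 * log_var) is a differentiable function of log_var.
  • eps is sampled independently of the network, so it is treated as a constant during backprop.
  • The KL divergence term of the ELBO, KL(q(z|x) || p(z)) with prior p(z) = N(0, I), has a closed form in mu and log_var. Together with the reparameterized reconstruction term, this makes the full VAE loss differentiable end-to-end.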
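To illustrate the last point, here is a sketch of the closed-form KL term, assuming the standard setup of a diagonal Gaussian posterior and a standard normal prior (the case worked out by Kingma & Welling): KL = -0.5 * sum_j (1 + log_var_j - mu_j^2 - exp(log_var_j)). The helper name kl_to_standard_normal is hypothetical. Its gradients with respect to mu and log_var are analytic and involve no sampling.

```python
import torch


def kl_to_standard_normal(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), one value per row."""
    # Closed form for a diagonal Gaussian against a standard normal prior,
    # summed over the latent dimension d.
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)


mu = torch.tensor([[0.0, 0.0], [1.0, -1.0]], requires_grad=True)
log_var = torch.tensor([[0.0, 0.0], [0.5, -0.5]], requires_grad=True)

kl = kl_to_standard_normal(mu, log_var)
print(kl)  # first row is 0: that posterior already equals the prior

kl.sum().backward()
# Gradients reach both encoder outputs directly:
print(mu.grad)       # equals mu, since dKL/dmu = mu
print(log_var.grad)  # equals 0.5 * (exp(log_var) - 1)
```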

Reference

Kingma & Welling, “Auto-Encoding Variational Bayes” (2013).

PRNG note

PyTorch and JAX use different pseudo-random number generators. Given the same seed, they will produce different samples. The expected outputs for this problem are generated using PyTorch only. Your JAX solution will be tested for correctness of the algorithm (right distribution, right shape), not for exact value matching against PyTorch’s samples.
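Because exact values are not comparable across PRNGs, one framework-agnostic way to test an implementation is to check shape and sample moments: over many rows sharing the same mu and log_var, the column means of z should approach mu and the column variances should approach exp(log_var). The sketch below does this in PyTorch; the batch size, the torch.Generator seeding, and the tolerances are assumptions for illustration, not the grader's actual tests. The same checks carry over directly to a JAX version.

```python
import torch


def reparameterize(mu, log_var, seed):
    # Same z = mu + exp(0.5 * log_var) * eps as in the Algorithm section.
    # Assumption: seeding through a local torch.Generator.
    gen = torch.Generator().manual_seed(seed)
    eps = torch.randn(mu.shape, generator=gen, dtype=mu.dtype)
    return mu + torch.exp(0.5 * log_var) * eps


# Many rows with identical parameters, so column statistics estimate moments.
N, d = 200_000, 3
mu = torch.tensor([0.0, 2.0, -1.0]).expand(N, d)
log_var = torch.tensor([0.0, 1.0, -2.0]).expand(N, d)

z = reparameterize(mu, log_var, seed=0)

# Shape check: output has the same (N, d) shape as the inputs.
assert z.shape == (N, d)

# Distribution check: sample moments should be close to the targets.
print(z.mean(dim=0))  # close to [0.0, 2.0, -1.0]
print(z.var(dim=0))   # close to exp([0.0, 1.0, -2.0]), about [1.0, 2.718, 0.135]
```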

Inputs / Output

  • mu: tensor of shape (N, d) — per-sample mean.
  • log_var: tensor of shape (N, d) — per-sample log-variance.
  • seed: int — random seed for reproducibility.

Output: z of shape (N, d).
