VAE Reparameterization Trick
Implement the reparameterization trick used in Variational Autoencoders (VAEs).
The Problem
In a VAE, the encoder outputs a distribution q(z|x) = N(mu, sigma^2) over the
latent variable z. To train the model with gradient descent, we need to
backpropagate through a sample from this distribution.
Sampling is a stochastic operation, so the sample itself offers no gradient with respect to mu and sigma. The reparameterization trick sidesteps this by expressing the sample as a deterministic function of the parameters plus a noise variable drawn from a fixed distribution:
z = mu + sigma * eps, eps ~ N(0, I)
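Since z is an affine function of the Gaussian eps, it is Gaussian too, and its moments match the encoder's distribution (per latent dimension, with sigma applied elementwise, which corresponds to a diagonal covariance):

```latex
\begin{aligned}
\mathbb{E}[z] &= \mu + \sigma\,\mathbb{E}[\epsilon] = \mu, \\
\operatorname{Var}[z] &= \sigma^2\,\operatorname{Var}[\epsilon] = \sigma^2
\quad\Longrightarrow\quad z \sim \mathcal{N}(\mu, \sigma^2).
\end{aligned}
```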
Now gradients flow through mu and sigma (both deterministic), while eps is treated as a random constant, so no gradient is needed through the sampling step.
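To see the difference concretely, here is a small sketch in PyTorch (assumed here because the expected outputs are generated with it). `torch.distributions.Normal` exposes both `sample()`, which draws without tracking gradients, and `rsample()`, which uses the reparameterized form. The all-zero `mu` and `log_var` tensors are illustrative stand-ins for encoder outputs.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Stand-ins for encoder outputs: leaf tensors that require gradients.
mu = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)
sigma = torch.exp(0.5 * log_var)

dist = Normal(mu, sigma)

# Plain sampling runs under no_grad, so the sample is cut off from the graph.
z_plain = dist.sample()
print(z_plain.requires_grad)  # False

# Reparameterized sampling computes mu + sigma * eps, so z stays in the graph.
z_reparam = dist.rsample()
print(z_reparam.requires_grad)  # True

z_reparam.sum().backward()
print(mu.grad)  # tensor of ones, since dz/dmu = 1
```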
Algorithm
sigma = exp(0.5 * log_var) # convert log-variance to std dev
eps = sample from N(0, I) with the same shape as mu, using the given seed
z = mu + sigma * eps
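A minimal PyTorch sketch of the three steps above. The function name `reparameterize` and the use of a local `torch.Generator` for seeding are assumptions: the source does not say exactly how the reference outputs were seeded (for example, `torch.manual_seed(seed)` followed by `torch.randn_like(mu)` would use the global generator instead), so values may not match the expected outputs even though the distribution is the same.

```python
import torch


def reparameterize(mu: torch.Tensor, log_var: torch.Tensor, seed: int) -> torch.Tensor:
    """Return z = mu + sigma * eps with eps ~ N(0, I); z has the same shape as mu."""
    if mu.shape != log_var.shape:
        raise ValueError(f"shape mismatch: {tuple(mu.shape)} vs {tuple(log_var.shape)}")

    # Convert log-variance to standard deviation: exp(0.5 * log(sigma^2)) = sigma.
    sigma = torch.exp(0.5 * log_var)

    # Assumption: seed a local CPU generator; the reference may seed differently.
    gen = torch.Generator().manual_seed(seed)
    eps = torch.randn(mu.shape, generator=gen, dtype=mu.dtype).to(mu.device)

    # eps has requires_grad=False, so gradients flow only into mu and sigma
    # (and from sigma back into log_var).
    return mu + sigma * eps


# Usage: gradients reach the "encoder outputs" mu and log_var.
mu = torch.zeros(3, 2, requires_grad=True)
log_var = torch.zeros(3, 2, requires_grad=True)
z = reparameterize(mu, log_var, seed=42)
print(z.shape)  # torch.Size([3, 2])

z.sum().backward()
print(mu.grad)  # tensor of ones, since dz/dmu = 1
```

Seeding a local generator rather than the global one keeps the call free of side effects on other random draws in the program.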
Why It Works
- mu and sigma come from the encoder network (sigma via sigma = exp(0.5 * log_var)), so both are differentiable.
- eps is sampled independently of the network, so it is treated as a constant during backprop.
- The KL divergence term in the ELBO loss also depends on mu and log_var, so the full VAE loss is differentiable end-to-end.
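For the standard prior p(z) = N(0, I) (assumed here; the source does not state the prior), the KL term has the closed form KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) over the latent dimensions, as derived in Kingma & Welling. The sketch below, using a hypothetical helper `kl_to_standard_normal`, computes it in PyTorch, cross-checks it against `torch.distributions.kl_divergence`, and confirms that gradients reach both mu and log_var.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Closed-form KL(N(mu, sigma^2) || N(0, I)) per sample, summed over latent dims."""
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)


torch.manual_seed(0)
mu = torch.randn(5, 3, requires_grad=True)
log_var = torch.randn(5, 3, requires_grad=True)

kl = kl_to_standard_normal(mu, log_var)  # shape (5,), one value per sample

# Cross-check with torch.distributions (Normal takes a std dev, not a variance).
q = Normal(mu, torch.exp(0.5 * log_var))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, prior).sum(dim=-1)
print(torch.allclose(kl, reference, atol=1e-5))  # True

# Backprop through the KL term alone already reaches both encoder outputs.
kl.mean().backward()
print(mu.grad is not None, log_var.grad is not None)  # True True
```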
Reference
Kingma & Welling, “Auto-Encoding Variational Bayes” (2013).
PRNG note
PyTorch and JAX use different pseudo-random number generators. Given the same
seed, they will produce different samples. The expected outputs for this
problem are generated using PyTorch only. Your JAX solution will be tested
for correctness of the algorithm (right distribution, right shape), not for
exact value matching against PyTorch’s samples.
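The source does not spell out how "right distribution, right shape" is checked. One plausible check, sketched below as a hypothetical helper `check_reparam_sample`, inverts the trick: if z ~ N(mu, sigma^2), then (z - mu) / sigma should look like N(0, 1) in every latent dimension. The tolerance and sample size are arbitrary choices; the same idea applies to samples produced by a JAX solution.

```python
import torch


def check_reparam_sample(z, mu, log_var, tol=0.02):
    """Hypothetical test: right shape and right distribution, not exact values."""
    if z.shape != mu.shape:
        return False
    # Undo the trick: if z ~ N(mu, sigma^2), eps should be ~ N(0, 1) per column.
    eps = (z - mu) / torch.exp(0.5 * log_var)
    mean_ok = (eps.mean(dim=0).abs() < tol).all()
    std_ok = ((eps.std(dim=0) - 1.0).abs() < tol).all()
    return bool(mean_ok and std_ok)


# A large batch (N = 100,000 rows, d = 2) so the sample statistics are tight.
n, d = 100_000, 2
mu = torch.tensor([1.0, -3.0]).expand(n, d)
log_var = torch.tensor([0.0, 2.0]).expand(n, d)
gen = torch.Generator().manual_seed(123)
z = mu + torch.exp(0.5 * log_var) * torch.randn(n, d, generator=gen)

print(check_reparam_sample(z, mu, log_var))  # expected True

# A common bug: scaling by the variance instead of the standard deviation.
z_wrong = mu + torch.exp(log_var) * torch.randn(n, d, generator=gen)
print(check_reparam_sample(z_wrong, mu, log_var))  # expected False
```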
Inputs / Output
- mu: tensor of shape (N, d), the per-sample mean.
- log_var: tensor of shape (N, d), the per-sample log-variance.
- seed: int, the random seed for reproducibility.
Output: z of shape (N, d).