Reparameterization Trick: Gaussian
Why this matters
The reparameterization trick is what makes VAE encoders trainable. A naïve
sample z ~ N(mu, sigma^2) is a stochastic node that blocks gradient flow
back through mu and sigma. The reparam trick rewrites the sample as
z = mu + sigma * eps, where eps ~ N(0, 1) is the only stochastic node.
Because eps carries no parameters, gradients flow cleanly through mu and
sigma — enabling ELBO optimization via standard backprop.
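To see that gradients really do reach mu and log_sigma, here is a minimal JAX sketch that differentiates a single reparameterized sample. The helper name sample_z is hypothetical. Analytically, dz/dmu = 1 and dz/dlog_sigma = sigma * eps, and jax.grad should reproduce both.

```python
import jax
import jax.numpy as jnp

def sample_z(mu, log_sigma, eps):
    # Reparameterized sample: all randomness lives in eps, so z is a
    # deterministic, differentiable function of (mu, log_sigma).
    return mu + jnp.exp(log_sigma) * eps

# One fixed scalar noise draw; jax.grad needs a scalar output.
eps = jax.random.normal(jax.random.PRNGKey(0), ())

# Partial derivatives of z with respect to both parameters.
dz_dmu, dz_dlog_sigma = jax.grad(sample_z, argnums=(0, 1))(2.0, 0.5, eps)

print(dz_dmu)         # z shifts one-for-one with mu
print(dz_dlog_sigma)  # should equal exp(0.5) * eps
```

Because eps is passed in as a plain input, it acts as a constant during differentiation, which is exactly the property the trick relies on.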
This pattern appears in the standard Gaussian VAE encoder: the network outputs mu
and log_sigma (unconstrained), log_sigma is exponentiated to recover sigma, and
the affine transform is then applied to a noise sample.
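A rough sketch of that encoder pattern, assuming a single hypothetical linear layer per output head (the names encode, sample_latent, W_mu, W_ls, b_mu, b_ls and the toy sizes are illustrative, not from any particular VAE codebase). A real encoder would place a deeper network in front of these two heads.

```python
import jax
import jax.numpy as jnp

def encode(params, x):
    # Hypothetical linear heads: one affine map for mu, one for log_sigma.
    mu = x @ params["W_mu"] + params["b_mu"]
    log_sigma = x @ params["W_ls"] + params["b_ls"]  # unconstrained, any real value
    return mu, log_sigma

def sample_latent(params, x, key):
    mu, log_sigma = encode(params, x)
    sigma = jnp.exp(log_sigma)              # always positive
    eps = jax.random.normal(key, mu.shape)  # one N(0, 1) draw per latent dimension
    return mu + sigma * eps                 # affine transform of the noise

# Toy sizes (assumptions): batch of 4 inputs, 3 features, 2 latent dims.
k_w1, k_w2, k_x, k_eps = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "W_mu": 0.1 * jax.random.normal(k_w1, (3, 2)),
    "b_mu": jnp.zeros(2),
    "W_ls": 0.1 * jax.random.normal(k_w2, (3, 2)),
    "b_ls": jnp.zeros(2),
}
x = jax.random.normal(k_x, (4, 3))
z = sample_latent(params, x, k_eps)

# Gradients reach every encoder parameter through the sampled z.
grads = jax.grad(lambda p: jnp.sum(sample_latent(p, x, k_eps)))(params)
```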
Worked mini-example
import jax, jax.numpy as jnp
key = jax.random.PRNGKey(0)
mu, log_sigma = 2.0, 0.5
eps = jax.random.normal(key, (1,)) # N(0,1) — the stochastic node
sigma = jnp.exp(log_sigma) # recover sigma from unconstrained param
z = mu + sigma * eps # differentiable w.r.t. mu and log_sigma
# z is a sample from N(2.0, exp(0.5)^2): mean 2.0, std ≈ 1.65, variance = e ≈ 2.72
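A quick empirical sanity check, assuming a large sample size is acceptable: applying the same affine transform to many eps draws should give a sample mean near mu = 2.0 and a sample standard deviation near sigma = exp(0.5) ≈ 1.649.

```python
import jax
import jax.numpy as jnp

mu, log_sigma = 2.0, 0.5
eps = jax.random.normal(jax.random.PRNGKey(0), (100_000,))  # many N(0, 1) draws
z = mu + jnp.exp(log_sigma) * eps                           # same affine transform

print(z.mean())  # should be close to 2.0
print(z.std())   # should be close to exp(0.5) ≈ 1.649
```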
Common pitfalls
- Parameterize as log_sigma, not sigma: log_sigma is unconstrained (any real number), while sigma must be positive. Recover sigma via exp(log_sigma).
- Keep the noise separate: draw eps ~ N(0, 1), then scale and shift it. Do NOT draw from N(mu, sigma^2) directly; that blocks gradients.
- Shape is (int(n),): n may arrive as a float, so cast it before use.
- PRNGKey needs a Python int: jax.random.PRNGKey(int(seed)).
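To make "blocks gradients" concrete, the sketch below models a sampler that is not written as a differentiable function of its parameters by wrapping the draw in jax.lax.stop_gradient. This is an illustrative stand-in (an assumption), not how any particular library implements sampling; the names reparam_sample and opaque_sample are hypothetical.

```python
import jax
import jax.numpy as jnp

eps = jax.random.normal(jax.random.PRNGKey(0), ())  # shared scalar noise draw

def reparam_sample(mu, log_sigma):
    # Noise drawn separately, then scaled and shifted: differentiable.
    return mu + jnp.exp(log_sigma) * eps

def opaque_sample(mu, log_sigma):
    # Same value, but treated as a black box: gradients are cut off.
    return jax.lax.stop_gradient(mu + jnp.exp(log_sigma) * eps)

g_reparam = jax.grad(reparam_sample, argnums=(0, 1))(2.0, 0.5)
g_opaque = jax.grad(opaque_sample, argnums=(0, 1))(2.0, 0.5)

print(g_reparam)  # nonzero: mu-gradient 1.0, log_sigma-gradient exp(0.5) * eps
print(g_opaque)   # both zero: no learning signal for mu or log_sigma
```

Both functions return the same sample value; only the reparameterized one gives the optimizer anything to work with.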
Problem
Implement reparam_normal(seed, mu, log_sigma, n) that returns n samples
from N(mu, exp(log_sigma)^2) using the reparameterization trick.
One illustrative example (not from the test set):
- reparam_normal(0, 2.0, 0.5, 3) returns a 1-D float32 array of shape (3,) where each element is 2.0 + exp(0.5) * eps_i for eps_i ~ N(0, 1), deterministic for seed 0.
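One possible sketch of reparam_normal that follows the pitfalls listed earlier. It assumes the checker looks at shape, dtype, and the mu + exp(log_sigma) * eps_i relationship; a reference solution might split the key or draw noise differently, so the exact values for a given seed are not guaranteed to match it.

```python
import jax
import jax.numpy as jnp

def reparam_normal(seed, mu, log_sigma, n):
    """Return n samples from N(mu, exp(log_sigma)^2) via z = mu + sigma * eps."""
    key = jax.random.PRNGKey(int(seed))      # PRNGKey needs a Python int
    eps = jax.random.normal(key, (int(n),))  # n may arrive as a float; float32 under default config
    sigma = jnp.exp(log_sigma)               # sigma > 0 for any real log_sigma
    return mu + sigma * eps                  # differentiable w.r.t. mu and log_sigma

z = reparam_normal(0, 2.0, 0.5, 3)
print(z.shape, z.dtype)  # (3,) float32
```

Since the noise is the only stochastic input, jax.grad of any loss built on the output flows back into mu and log_sigma.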