Reparameterization Gradient
Why this matters
The reparameterisation trick (Kingma & Welling 2013) makes it possible to
differentiate through stochastic nodes by rewriting
z ~ N(μ, σ²) as z = μ + σ·ε, ε ~ N(0,1). The randomness is moved
into ε, whose distribution has no parameters, so jax.grad can flow through
z unobstructed. This is the backbone of VAE training, flow-based models,
and any differentiable Monte Carlo estimator. Compared to REINFORCE
(the score-function estimator), the reparam gradient typically has much lower
variance because it differentiates the objective along each sample path
instead of weighting the objective value by ∇ log p(z).
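The variance claim can be checked empirically. The sketch below is illustrative only: it assumes the objective E[z²] with μ = σ = 1 and 10 000 samples (values chosen for this example, not taken from the problem), and compares per-sample estimates of ∂/∂μ from the pathwise (reparam) estimator with those from the score-function estimator z²·(z − μ)/σ².

```python
import jax
import jax.numpy as jnp

mu, sigma = 1.0, 1.0  # illustrative values, not part of the problem
eps = jax.random.normal(jax.random.PRNGKey(0), (10_000,))

def f_path(m, e):
    # Objective evaluated on one reparameterised sample z = m + sigma * e
    z = m + sigma * e
    return z ** 2

# Pathwise estimator: one gradient w.r.t. mu per noise sample
g_path = jax.vmap(jax.grad(f_path), in_axes=(None, 0))(mu, eps)

# Score-function (REINFORCE) estimator: f(z) * d/dmu log N(z; mu, sigma^2)
z = mu + sigma * eps
g_score = z ** 2 * (z - mu) / sigma ** 2

print("pathwise mean/var:", jnp.mean(g_path), jnp.var(g_path))
print("score-fn mean/var:", jnp.mean(g_score), jnp.var(g_score))
```

Both estimators target the same gradient (2μ = 2 here), but for this objective the score-function samples are spread far more widely around it.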
Worked mini-example
Objective: E[z²] where z ~ N(μ, σ²), σ = exp(log_σ).
Reparam: z = μ + σ·ε, ε ~ N(0,1).
Loss = mean(z²) = mean((μ + σ·ε)²).
∂loss/∂μ = 2·mean(μ + σ·ε) ≈ 2μ (since E[ε] = 0).
∂loss/∂log_σ = 2·mean((μ + σ·ε)·σ·ε) via the chain rule, which is ≈ 2σ² (since E[ε] = 0 and E[ε²] = 1).
Analytically: E[z²] = μ² + σ², so ∂/∂μ = 2μ and ∂/∂log_σ = 2σ².
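The worked example can be verified numerically. This is a minimal sketch, assuming illustrative values μ = 0.5 and log_σ = −0.3 and a larger sample count (100 000) than the problem asks for, so that the Monte Carlo estimates sit close to the analytic values 2μ and 2σ².

```python
import jax
import jax.numpy as jnp

mu, log_sigma = 0.5, -0.3  # illustrative values
# Fixed noise, drawn once and closed over by the loss
eps = jax.random.normal(jax.random.PRNGKey(0), (100_000,))

def loss(mu, log_sigma):
    sigma = jnp.exp(log_sigma)
    z = mu + sigma * eps          # reparameterised samples
    return jnp.mean(z ** 2)       # Monte Carlo estimate of E[z^2]

g_mu, g_log_sigma = jax.grad(loss, argnums=(0, 1))(mu, log_sigma)

sigma = jnp.exp(log_sigma)
print("d/dmu:        estimate", g_mu, "analytic", 2 * mu)
print("d/dlog_sigma: estimate", g_log_sigma, "analytic", 2 * sigma ** 2)
```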
Common pitfalls
- Draw ε OUTSIDE the loss function: ε should be a fixed array closed over by loss(mu, log_sigma), so that it is a constant with respect to the parameters. JAX does not differentiate through the sampling operation itself; gradients reach μ and log_σ only via z = μ + σ·ε.
- argnums=(0, 1): jax.grad(loss) defaults to argnums=0 (gradient w.r.t. the first argument only). Use argnums=(0, 1) to get both ∂/∂μ and ∂/∂log_σ.
- Chain rule on log_σ: ∂/∂log_σ includes the Jacobian of σ = exp(log_σ), i.e. ∂/∂log_σ = σ · ∂/∂σ, so the analytic value is 2σ², not 2σ. JAX handles this automatically via jax.grad.
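The argnums and chain-rule pitfalls can be seen side by side in a short sketch. The helper loss_sigma, which takes σ directly instead of log_σ, is hypothetical and exists only for this comparison; the parameter values are likewise illustrative.

```python
import jax
import jax.numpy as jnp

eps = jax.random.normal(jax.random.PRNGKey(0), (1000,))  # fixed noise

def loss(mu, log_sigma):
    z = mu + jnp.exp(log_sigma) * eps
    return jnp.mean(z ** 2)

def loss_sigma(mu, sigma):
    # Hypothetical variant parameterised by sigma itself, for comparison
    z = mu + sigma * eps
    return jnp.mean(z ** 2)

mu, log_sigma = 0.5, -0.3  # illustrative values
sigma = jnp.exp(log_sigma)

# Default argnums=0: a single scalar, the gradient w.r.t. mu only
g_default = jax.grad(loss)(mu, log_sigma)

# argnums=(0, 1): a tuple with both gradients
g_mu, g_log_sigma = jax.grad(loss, argnums=(0, 1))(mu, log_sigma)

# Chain rule: d/dlog_sigma equals sigma * d/dsigma
g_sigma = jax.grad(loss_sigma, argnums=1)(mu, sigma)
print(g_log_sigma, sigma * g_sigma)
```

The two printed values should agree up to floating-point error, which is the extra factor of σ that the log parameterisation introduces.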
Problem
Implement reparam_grad(seed, mu, log_sigma) that computes the gradient
of E[z²] w.r.t. (mu, log_sigma) using the reparameterisation trick with
1 000 samples.
- seed (float) → jax.random.PRNGKey(int(seed))
- mu, log_sigma: scalar floats
Return a 1-D float32 array of shape (2,) — [grad_mu, grad_log_sigma].