medium primitives

Reparameterization Gradient

Why this matters

The reparameterisation trick (Kingma & Welling 2013) makes it possible to differentiate through a stochastic node by rewriting z ~ N(μ, σ²) as z = μ + σ·ε with ε ~ N(0, 1). The randomness moves into ε, whose distribution has no parameters, so jax.grad can differentiate z with respect to μ and σ. This is the backbone of VAE training, flow-based models, and any differentiable Monte Carlo estimator. Compared with REINFORCE (the score-function estimator), the reparameterised gradient typically has much lower variance, because each sample contributes an exact derivative along its sample path rather than a function value weighted by a score.
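The variance claim can be checked numerically. The sketch below is illustrative only: it assumes the objective E[z²] and the values μ = 1, σ = 1 (chosen here, not taken from the problem), and compares the per-sample pathwise estimator 2z with the score-function estimator z²·(z − μ)/σ² for ∂/∂μ.

```python
import jax
import jax.numpy as jnp

mu, sigma = 1.0, 1.0  # illustrative values (assumption, not from the problem)
eps = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
z = mu + sigma * eps  # reparameterised samples of z ~ N(mu, sigma^2)

# Two per-sample estimators of d/dmu E[z^2]; the true value is 2 * mu.
reparam = 2.0 * z                   # pathwise: d(z^2)/dz * dz/dmu
score = z**2 * (z - mu) / sigma**2  # REINFORCE: f(z) * d log p(z) / dmu

reparam_mean, reparam_var = jnp.mean(reparam), jnp.var(reparam)
score_mean, score_var = jnp.mean(score), jnp.var(score)

print("reparam: mean", reparam_mean, "var", reparam_var)
print("score:   mean", score_mean, "var", score_var)
```

Both estimators should average to roughly 2μ, while the score-function estimator shows a noticeably larger per-sample variance for these values.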

Worked mini-example

Objective: E[z²] where z ~ N(μ, σ²) and σ = exp(log_σ).

Reparameterise: z = μ + σ·ε, ε ~ N(0, 1), so loss = mean(z²) = mean((μ + σ·ε)²).

Sample gradients (chain rule): ∂loss/∂μ = 2·mean(μ + σ·ε) and ∂loss/∂log_σ = 2·mean((μ + σ·ε)·σ·ε), using ∂z/∂log_σ = σ·ε.

Analytic check: E[z²] = μ² + σ², so ∂/∂μ = 2μ and ∂/∂log_σ = 2σ². The sample gradients match these in expectation because E[ε] = 0 and E[ε²] = 1.
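The hand-derived expressions above can be compared with the analytic values directly. This is a minimal sketch that uses no autodiff; the values μ = 0.5, log_σ = −0.3 and the sample count are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

mu, log_sigma = 0.5, -0.3  # illustrative values (assumption)
sigma = jnp.exp(log_sigma)
eps = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
z = mu + sigma * eps

# Monte Carlo gradients from the chain-rule expressions.
grad_mu_mc = 2.0 * jnp.mean(z)
grad_log_sigma_mc = 2.0 * jnp.mean(z * sigma * eps)

# Analytic gradients of E[z^2] = mu^2 + sigma^2.
grad_mu_exact = 2.0 * mu
grad_log_sigma_exact = 2.0 * sigma**2

print(grad_mu_mc, grad_mu_exact)
print(grad_log_sigma_mc, grad_log_sigma_exact)
```

With this many samples the two pairs should agree to roughly two decimal places; with fewer samples the gap grows.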

Common pitfalls

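  • Keep ε independent of the parameters: draw ε once from the PRNG key and treat it as a fixed array in loss(mu, log_sigma). Closing over ε drawn outside the loss is the simplest pattern. Drawing ε inside the loss from a fixed key also differentiates correctly in JAX, because ε does not depend on mu or log_sigma; what matters is that the gradient flows through z = μ + σ·ε and that the same ε is used throughout.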
  • argnums=(0, 1): jax.grad(loss) defaults to argnums=0 (grad w.r.t. first arg only). Use argnums=(0, 1) to get both ∂/∂μ and ∂/∂log_σ.
  • Chain rule on log_σ: ∂/∂log_σ includes the Jacobian of σ = exp(log_σ), which is ∂σ/∂log_σ = σ. The gradient is therefore σ · ∂/∂σ, giving 2σ² for this objective rather than 2σ. jax.grad applies this automatically when the loss takes log_sigma as its argument.
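The argnums and chain-rule points can be seen in a short sketch. The helper loss_sigma, which is parameterised by σ directly, is hypothetical and exists only for this comparison, and the parameter values are illustrative.

```python
import jax
import jax.numpy as jnp

eps = jax.random.normal(jax.random.PRNGKey(0), (1000,))  # fixed noise

def loss(mu, log_sigma):
    z = mu + jnp.exp(log_sigma) * eps
    return jnp.mean(z**2)

def loss_sigma(mu, sigma):
    # Hypothetical variant taking sigma directly, for comparison only.
    z = mu + sigma * eps
    return jnp.mean(z**2)

mu, log_sigma = 0.5, -0.3  # illustrative values (assumption)
sigma = jnp.exp(log_sigma)

# Default argnums=0: a single scalar, the gradient w.r.t. mu only.
g_default = jax.grad(loss)(mu, log_sigma)

# argnums=(0, 1): a tuple holding both gradients.
g_mu, g_log_sigma = jax.grad(loss, argnums=(0, 1))(mu, log_sigma)

# Chain rule: d/dlog_sigma equals sigma * d/dsigma on the same noise.
g_sigma = jax.grad(loss_sigma, argnums=1)(mu, sigma)
print(g_log_sigma, sigma * g_sigma)
```

Because both losses share the same ε, the last two printed values agree up to floating-point error, not just in expectation.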

Problem

Implement reparam_grad(seed, mu, log_sigma) that computes the gradient of E[z²] w.r.t. (mu, log_sigma) using the reparameterisation trick with 1 000 samples.

  • seed (float) → jax.random.PRNGKey(int(seed))
  • mu, log_sigma — scalar floats

Return a 1-D float32 array of shape (2,): [grad_mu, grad_log_sigma].
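One way the specification could be met is sketched below. It is not the reference solution, which is not shown here; the constant NUM_SAMPLES and the inner loss function are names introduced for this sketch.

```python
import jax
import jax.numpy as jnp

NUM_SAMPLES = 1000  # sample count from the problem statement

def reparam_grad(seed, mu, log_sigma):
    # Draw the noise once so it is a constant with respect to the parameters.
    key = jax.random.PRNGKey(int(seed))
    eps = jax.random.normal(key, (NUM_SAMPLES,))

    def loss(mu, log_sigma):
        z = mu + jnp.exp(log_sigma) * eps  # reparameterised sample
        return jnp.mean(z**2)              # Monte Carlo estimate of E[z^2]

    # Cast to float32 so integer inputs do not break jax.grad.
    mu = jnp.asarray(mu, dtype=jnp.float32)
    log_sigma = jnp.asarray(log_sigma, dtype=jnp.float32)

    g_mu, g_log_sigma = jax.grad(loss, argnums=(0, 1))(mu, log_sigma)
    return jnp.stack([g_mu, g_log_sigma]).astype(jnp.float32)

grads = reparam_grad(0.0, 0.5, -0.3)
print(grads)  # should be close to [2*mu, 2*sigma**2], up to sampling noise
```

With only 1000 samples the result carries visible Monte Carlo noise, so a comparison against the analytic values needs a loose tolerance.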

