medium primitives

Reparameterization Trick: Gaussian

Why this matters

The reparameterization trick is what makes VAE encoders trainable. A naïve sample z ~ N(mu, sigma^2) is a stochastic node: the sampling operation is not differentiable with respect to mu and sigma, so it blocks gradient flow back to them. The reparam trick rewrites the sample as z = mu + sigma * eps, where eps ~ N(0, 1) is the only stochastic node. Because eps carries no parameters, z is a deterministic, differentiable function of mu and sigma, and gradients flow through both. This is what allows ELBO optimization via standard backprop.

This pattern is the standard one for Gaussian VAE encoders: the network outputs mu and log_sigma (unconstrained), you exponentiate to recover sigma, then apply the affine transform to a noise sample.
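The gradient claim above can be checked directly. The following is a minimal sketch, assuming a scalar sample; the helper name sample_z is hypothetical and used only for illustration. It differentiates the reparameterized sample with jax.grad and compares against the analytic derivatives dz/dmu = 1 and dz/dlog_sigma = exp(log_sigma) * eps.

```python
import jax
import jax.numpy as jnp

def sample_z(mu, log_sigma, eps):
    # Hypothetical helper: reparameterized sample, all randomness lives in eps.
    return mu + jnp.exp(log_sigma) * eps

key = jax.random.PRNGKey(0)
eps = jax.random.normal(key, ())  # scalar noise, drawn once and then held fixed

mu, log_sigma = 2.0, 0.5
# Differentiate z with respect to mu (arg 0) and log_sigma (arg 1).
dz_dmu, dz_dlog_sigma = jax.grad(sample_z, argnums=(0, 1))(mu, log_sigma, eps)

# Analytic values for comparison.
expected_dmu = 1.0
expected_dlog_sigma = jnp.exp(log_sigma) * eps
```

Both gradients are well defined because eps is treated as a fixed input rather than as something that depends on the parameters.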

Worked mini-example

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu, log_sigma = 2.0, 0.5

eps = jax.random.normal(key, (1,))   # N(0,1): the only stochastic node
sigma = jnp.exp(log_sigma)           # recover sigma from unconstrained param
z = mu + sigma * eps                 # differentiable w.r.t. mu and log_sigma
# z is one sample from N(2.0, exp(0.5)^2): std ≈ 1.65, variance ≈ 2.72

Common pitfalls

  • Parameterize as log_sigma, not sigma: log_sigma is unconstrained (any real number), while sigma must be positive. Recover via exp(log_sigma).
  • The noise must be separate: draw eps ~ N(0, 1), then shift and scale. Do NOT draw from N(mu, sigma^2) directly, because that blocks gradients.
  • Shape is (int(n),): n may arrive as a float; cast before use.
  • PRNGKey needs a Python int: jax.random.PRNGKey(int(seed)).
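To make the first pitfall concrete, here is a small sketch of why log_sigma is the safer parameter. The update size lr_times_grad is a made-up number chosen for illustration, not something from a real training run.

```python
import jax.numpy as jnp

lr_times_grad = 0.3  # hypothetical update size (learning rate times gradient)

# Updating sigma directly can push it below zero.
sigma = 0.1
sigma_after = sigma - lr_times_grad  # about -0.2: not a valid standard deviation

# Updating log_sigma instead keeps the recovered sigma positive,
# because exp maps every real number to a positive one.
log_sigma = jnp.log(0.1)
sigma_from_log = jnp.exp(log_sigma - lr_times_grad)
```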

Problem

Implement reparam_normal(seed, mu, log_sigma, n) that returns n samples from N(mu, exp(log_sigma)^2) using the reparameterization trick.

One illustrative example (not from the test set):

  • reparam_normal(0, 2.0, 0.5, 3) returns a 1-D float32 array of shape (3,) where element i is 2.0 + exp(0.5) * eps_i, with each eps_i ~ N(0, 1). The result is deterministic for a fixed seed (here, 0).
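One possible shape for a solution, offered as a sketch rather than the reference answer (which the page does not show): it follows the pitfalls listed earlier by casting seed and n to int and by exponentiating log_sigma. It assumes JAX's default float32 mode and scalar mu and log_sigma.

```python
import jax
import jax.numpy as jnp

def reparam_normal(seed, mu, log_sigma, n):
    # Sketch only: the cast guards against seed or n arriving as floats.
    key = jax.random.PRNGKey(int(seed))
    eps = jax.random.normal(key, (int(n),))  # parameter-free noise, shape (n,)
    sigma = jnp.exp(log_sigma)               # positive by construction
    return mu + sigma * eps                  # affine transform of the noise

samples = reparam_normal(0, 2.0, 0.5, 3)
```

Calling the function twice with the same seed yields identical arrays, since all randomness comes from the key.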

Hints

jax reparam vae
