Multivariate Normal Sampling
Why this matters
The multivariate normal (MVN) is the workhorse of probabilistic models.
Sampling from N(μ, Σ) is central to: VAE encoders with full covariance,
Gaussian processes (prior and posterior draws), Kalman filters, and Bayesian
neural networks. JAX provides jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky'). With the default method it computes a Cholesky factor of the covariance
and applies the reparameterization z = μ + Lε, where L = chol(Σ) is lower-triangular with LLᵀ = Σ and ε ~ N(0, I).
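Why the affine map z = μ + Lε has the right distribution: z is an affine function of a Gaussian, so it is Gaussian, and matching its mean and covariance is enough. Using E[ε] = 0 and E[εεᵀ] = I:

```latex
\mathbb{E}[z] = \mu + L\,\mathbb{E}[\varepsilon] = \mu,
\qquad
\operatorname{Cov}(z) = \mathbb{E}\big[(L\varepsilon)(L\varepsilon)^\top\big]
= L\,\mathbb{E}[\varepsilon\varepsilon^\top]\,L^\top
= L I L^\top = \Sigma .
```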
Understanding MVN sampling is a prerequisite for implementing the reparameterization trick for general Gaussian posteriors, not just the diagonal case.
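A minimal sketch of the full-covariance reparameterization trick, written by hand instead of calling jax.random.multivariate_normal so that gradients visibly flow through both μ and L. The helper names sample_full_cov and mc_objective are hypothetical, and the objective E[‖z‖²] = ‖μ‖² + tr(LLᵀ) is chosen only because its exact gradients (2μ and 2L) are easy to check against the Monte Carlo estimate.

```python
import jax
import jax.numpy as jnp

def sample_full_cov(key, mu, L_raw, n):
    """Hypothetical helper: n reparameterized draws z = mu + L @ eps."""
    # Keep only the lower triangle so L is a valid Cholesky-style factor.
    L = jnp.tril(L_raw)
    # Standard normal noise, one row per sample: shape (n, d).
    eps = jax.random.normal(key, (n, mu.shape[0]))
    # Row-wise z_i = mu + L @ eps_i, written as eps @ L.T for the whole batch.
    return mu + eps @ L.T

def mc_objective(mu, L_raw, key, n=200_000):
    """Hypothetical objective: Monte Carlo estimate of E[||z||^2]."""
    z = sample_full_cov(key, mu, L_raw, n)
    return jnp.mean(jnp.sum(z ** 2, axis=-1))

key = jax.random.PRNGKey(0)
mu = jnp.array([1.0, -0.5])
# L @ L.T = [[1.0, 0.6], [0.6, 1.0]]: unit variances, correlation 0.6.
L_raw = jnp.array([[1.0, 0.0], [0.6, 0.8]])

# Gradients flow through the sampling step because eps does not depend on mu or L.
grad_mu, grad_L = jax.grad(mc_objective, argnums=(0, 1))(mu, L_raw, key)
print(grad_mu)  # close to 2 * mu = [2.0, -1.0]
print(grad_L)   # close to 2 * L, with an exact 0 in the upper-right entry
```

The diagonal-covariance case is the special case where L is diagonal; the full lower-triangular factor is what lets the posterior capture correlations.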
Worked mini-example
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
mean = jnp.array([0.0, 0.0])
cov = jnp.array([[1.0, 0.0],
                 [0.0, 1.0]])  # isotropic 2-D Gaussian
# Draw 3 samples; output shape (3, 2)
samples = jax.random.multivariate_normal(key, mean, cov, shape=(3,))
# → float32 array of shape (3, 2)
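# Correlated covariance: unit variances, correlation 0.8
cov_corr = jnp.array([[1.0, 0.8], [0.8, 1.0]])
# Split off a fresh subkey instead of reusing `key` for a second draw
_, subkey = jax.random.split(key)
corr_samples = jax.random.multivariate_normal(subkey, mean, cov_corr, shape=(5,))
# → float32 array of shape (5, 2)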
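Five draws are too few to see the effect of the 0.8 off-diagonal entry. A quick sanity check, assuming an arbitrary seed of 42 and 50,000 draws, is to compare the empirical correlation and covariance against cov_corr:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
mean = jnp.array([0.0, 0.0])
cov_corr = jnp.array([[1.0, 0.8], [0.8, 1.0]])

# Many draws so sample statistics settle near the population values.
many = jax.random.multivariate_normal(key, mean, cov_corr, shape=(50_000,))

# Empirical correlation between the two coordinates (population value 0.8).
emp_corr = jnp.corrcoef(many[:, 0], many[:, 1])[0, 1]
# Empirical covariance; columns are variables, hence rowvar=False.
emp_cov = jnp.cov(many, rowvar=False)

print(many.shape)  # (50000, 2)
print(emp_corr)    # roughly 0.8
print(emp_cov)     # roughly cov_corr
```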
Common pitfalls
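- cov is the FULL covariance matrix of shape (d, d), not a vector of standard deviations. The default Cholesky factorization requires a symmetric positive-definite matrix; for a singular (positive semi-definite) covariance, pass method='eigh' or method='svd' instead.
- shape is the SAMPLE (batch) shape: shape=(n,) means n draws, and the output has shape (n, d).
- n may arrive as a float: cast it to int before building the shape tuple.
- A non-positive-definite covariance typically does not raise an exception in JAX; the Cholesky factor comes back as NaNs, so the samples silently become NaN. Symmetry and a positive diagonal are necessary but not sufficient: [[1, 2], [2, 1]] has eigenvalues 3 and -1. Off-diagonal entries control correlations and must stay small enough that every eigenvalue remains positive.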
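The pitfalls above can be checked directly. This is a sketch under stated assumptions: is_positive_definite and its tol=1e-6 threshold are hypothetical choices, not a JAX API, and the rank-1 example uses the documented method="svd" option of jax.random.multivariate_normal to sidestep the Cholesky failure on a singular matrix.

```python
import jax
import jax.numpy as jnp

def is_positive_definite(cov, tol=1e-6):
    """Hypothetical helper: True if cov is symmetric with all eigenvalues > tol."""
    symmetric = bool(jnp.allclose(cov, cov.T))
    # eigvalsh assumes a symmetric input and returns eigenvalues in ascending order.
    min_eig = float(jnp.linalg.eigvalsh(cov)[0])
    return symmetric and min_eig > tol

# Symmetric with a positive diagonal, yet NOT positive-definite (eigenvalues -1 and 3).
bad = jnp.array([[1.0, 2.0], [2.0, 1.0]])
good = jnp.array([[1.0, 0.8], [0.8, 1.0]])
print(is_positive_definite(bad))   # False
print(is_positive_definite(good))  # True

# n may arrive as a float; shape tuples need Python ints.
n = 3.0
key = jax.random.PRNGKey(0)
samples = jax.random.multivariate_normal(key, jnp.zeros(2), good, shape=(int(n),))
print(samples.shape)  # (3, 2)

# Rank-1 (singular) but positive semi-definite covariance: use an SVD factor instead.
singular = jnp.array([[1.0, 1.0], [1.0, 1.0]])
tied = jax.random.multivariate_normal(
    key, jnp.zeros(2), singular, shape=(1000,), method="svd")
# Rank 1 means both coordinates move together: tied[:, 0] is numerically tied[:, 1].
```

Checking eigenvalues up front costs an extra O(d³) decomposition, so in a hot loop it is more common to validate once and then sample many times.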
Problem
Implement multivariate_normal_sample(seed, mean, cov, n) that draws n
samples from N(mean, cov).
mean is a 1-D JAX array of shape (d,). cov is a 2-D JAX array of
shape (d, d) — the full covariance matrix. Return a 2-D float32 array of
shape (n, d).
One illustrative example (not from the test set):
- multivariate_normal_sample(0, jnp.zeros(2), jnp.eye(2), 3.0) returns a float32 array of shape (3, 2), deterministic for seed 0.
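One possible sketch, assuming the integer seed is turned into a key with jax.random.PRNGKey(seed) and used for a single call. The hidden reference solution may derive or split keys differently, so this sketch is not guaranteed to reproduce the grader's exact values, only the required shape, dtype, and determinism.

```python
import jax
import jax.numpy as jnp

def multivariate_normal_sample(seed, mean, cov, n):
    # Assumption: one fresh key built directly from the integer seed.
    key = jax.random.PRNGKey(int(seed))
    mean = jnp.asarray(mean, dtype=jnp.float32)
    cov = jnp.asarray(cov, dtype=jnp.float32)
    # n may arrive as a float (e.g. 3.0); shape entries must be Python ints.
    return jax.random.multivariate_normal(
        key, mean, cov, shape=(int(n),), dtype=jnp.float32
    )

out = multivariate_normal_sample(0, jnp.zeros(2), jnp.eye(2), 3.0)
print(out.shape, out.dtype)  # (3, 2) float32
```

Because JAX random functions are pure given the key, calling the function twice with the same seed returns identical arrays.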