medium primitives

Multivariate Normal Sampling

Why this matters

The multivariate normal (MVN) is the workhorse of probabilistic models. Sampling from N(μ, Σ) is central to VAE encoders with full-covariance posteriors, Gaussian processes (prior and posterior draws), Kalman filters, and Bayesian neural networks. JAX provides jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky'). With the default method='cholesky' it factors the covariance as Σ = L Lᵀ, where L = chol(Σ) is lower-triangular, and applies the reparameterization z = μ + L ε with ε ~ N(0, I). The alternatives method='eigh' and method='svd' build the factor from an eigen- or singular-value decomposition instead.
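To make the reparameterization concrete, here is a minimal hand-written sketch of z = μ + L ε built from jnp.linalg.cholesky and jax.random.normal. The helper name mvn_reparam_sample is hypothetical, and the sketch is not guaranteed to reproduce the exact draws of jax.random.multivariate_normal; only the distribution should match. Each row of eps is a standard normal vector, so eps @ L.T applies L to every row at once.

```python
import jax
import jax.numpy as jnp


def mvn_reparam_sample(key, mean, cov, n):
    """Hypothetical helper: n draws from N(mean, cov) via z = mean + L @ eps.

    A sketch of the Cholesky reparameterization; not guaranteed to match
    jax.random.multivariate_normal sample-for-sample.
    """
    L = jnp.linalg.cholesky(cov)                       # lower-triangular, cov = L @ L.T
    eps = jax.random.normal(key, (n, mean.shape[-1]))  # eps ~ N(0, I), shape (n, d)
    return mean + eps @ L.T                            # row i equals mean + L @ eps[i]


key = jax.random.PRNGKey(0)
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6],
                 [0.6, 1.0]])

z = mvn_reparam_sample(key, mean, cov, 100_000)
print(z.shape)                     # (100000, 2)
print(jnp.mean(z, axis=0))         # close to mean
print(jnp.cov(z, rowvar=False))    # close to cov
```

With this many draws the sample mean and sample covariance should land close to the inputs, which is a cheap way to check that L is applied on the correct side.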

Understanding MVN sampling is a prerequisite for implementing the reparameterization trick for general Gaussian posteriors, not just the diagonal case.
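A sketch of the reparameterization trick for a full-covariance Gaussian: gradients reach both the mean and the Cholesky factor because z = μ + L ε is a differentiable function of them. The parameterization below (strict lower triangle used as-is, diagonal passed through exp so it stays positive) and the names scale_tril and expected_sq_norm are illustrative assumptions, not a fixed recipe. The objective E[||z||²] = ||μ||² + tr(Σ) is chosen only because its gradient with respect to μ, namely 2μ, is easy to check.

```python
import jax
import jax.numpy as jnp


def scale_tril(raw):
    """Assumed parameterization: map an unconstrained (d, d) matrix to a valid
    Cholesky factor by keeping the strict lower triangle and exp-ing the diagonal."""
    lower = jnp.tril(raw, k=-1)
    return lower + jnp.diag(jnp.exp(jnp.diag(raw)))


def expected_sq_norm(params, key, n=65_536):
    """Monte Carlo estimate of E[||z||^2] for z ~ N(mu, L L^T).
    Sampling as z = mu + L eps lets jax.grad flow into mu and raw."""
    mu, raw = params
    L = scale_tril(raw)
    eps = jax.random.normal(key, (n, mu.shape[0]))  # eps ~ N(0, I)
    z = mu + eps @ L.T
    return jnp.mean(jnp.sum(z ** 2, axis=-1))


mu = jnp.array([0.5, -1.0])
raw = jnp.zeros((2, 2))  # diagonal exp(0) = 1, so L = I and Sigma = I initially

grad_mu, grad_raw = jax.grad(expected_sq_norm)((mu, raw), jax.random.PRNGKey(1))
# Analytically d/dmu E[||z||^2] = 2 * mu.
print(grad_mu)  # approximately [1.0, -2.0]
```

Because the objective is a Monte Carlo estimate, the gradient only approximates 2μ; a larger n tightens it. The upper-triangular entries of raw are discarded by jnp.tril, so their gradient is exactly zero.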

Worked mini-example

import jax, jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_iso, key_corr = jax.random.split(key)  # independent keys; avoid reusing one key for two draws

mean = jnp.array([0.0, 0.0])
cov  = jnp.array([[1.0, 0.0],
                  [0.0, 1.0]])  # isotropic 2-D Gaussian (identity covariance)

# Draw 3 samples; output shape (3, 2)
samples = jax.random.multivariate_normal(key_iso, mean, cov, shape=(3,))
# → float32 array of shape (3, 2)

# Correlated covariance: unit variances, correlation 0.8
cov_corr = jnp.array([[1.0, 0.8], [0.8, 1.0]])
corr_samples = jax.random.multivariate_normal(key_corr, mean, cov_corr, shape=(5,))
# → float32 array of shape (5, 2)
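A quick sanity check on the correlated case, assuming JAX's default float32 precision: with many draws, the empirical correlation between the two coordinates should sit near the off-diagonal value 0.8 (both variances are 1, so covariance and correlation coincide). The sample count of 50,000 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mean = jnp.array([0.0, 0.0])
cov_corr = jnp.array([[1.0, 0.8],
                      [0.8, 1.0]])

# Many draws so that sample statistics settle near the true values.
draws = jax.random.multivariate_normal(key, mean, cov_corr, shape=(50_000,))
print(draws.shape, draws.dtype)  # (50000, 2) float32

# Pearson correlation between the two coordinates; expected to be near 0.8.
emp_corr = jnp.corrcoef(draws[:, 0], draws[:, 1])[0, 1]
print(emp_corr)
```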

Common pitfalls

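  • cov is the FULL covariance matrix of shape (d, d), not a vector of standard deviations. The default Cholesky method requires it to be symmetric positive definite; a singular (only positive semi-definite) covariance may fail there, and method='svd' or method='eigh' is usually the better choice for it.
  • shape is the batch (sample) shape, not the output shape: shape=(n,) means n draws and an output of shape (n, d). In general the output shape is shape + (d,).
  • n may arrive as a float (e.g. 3.0): cast it with int(n) before building the shape tuple, since shape entries must be integers.
  • A covariance that is not positive definite does not raise an error in JAX: the Cholesky factor comes back filled with NaNs, so the samples are NaN. Symmetry and a positive diagonal are necessary but not sufficient (for example [[1, 2], [2, 1]] has eigenvalues 3 and -1); check that every eigenvalue is positive. Off-diagonal entries set the covariances, and hence the correlations, between dimensions.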
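The pitfalls above can be exercised directly. The is_positive_definite helper below is hypothetical; it tests symmetry and the smallest eigenvalue from jnp.linalg.eigvalsh. The sketch also shows that a bad covariance yields NaN samples rather than an exception, and that a multi-axis shape, built from a float n cast to int, produces an output of shape shape + (d,).

```python
import jax
import jax.numpy as jnp


def is_positive_definite(cov):
    """Hypothetical helper: True if cov is symmetric with all eigenvalues > 0."""
    symmetric = jnp.allclose(cov, cov.T)
    min_eig = jnp.min(jnp.linalg.eigvalsh(cov))
    return bool(symmetric) and bool(min_eig > 0)


key = jax.random.PRNGKey(0)
mean = jnp.zeros(2)

good = jnp.array([[1.0, 0.8], [0.8, 1.0]])  # eigenvalues 1.8 and 0.2
bad = jnp.array([[1.0, 2.0], [2.0, 1.0]])   # symmetric, positive diagonal, eigenvalues 3 and -1

print(is_positive_definite(good))  # True
print(is_positive_definite(bad))   # False

# No exception for `bad`: the Cholesky factor is NaN, so the samples are NaN.
bad_draws = jax.random.multivariate_normal(key, mean, bad, shape=(3,))
print(jnp.isnan(bad_draws).any())  # True

# n arriving as a float: cast before building the shape tuple.
n = 4.0
batch = jax.random.multivariate_normal(key, mean, good, shape=(int(n), 3))
print(batch.shape)  # (4, 3, 2), i.e. shape + (d,)
```

Checking positive definiteness up front is cheap for small d and turns a silent NaN into an explicit decision point.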

Problem

Implement multivariate_normal_sample(seed, mean, cov, n) that draws n samples from N(mean, cov).

mean is a 1-D JAX array of shape (d,). cov is a 2-D JAX array of shape (d, d) — the full covariance matrix. Return a 2-D float32 array of shape (n, d).

One illustrative example (not from the test set):

  • multivariate_normal_sample(0, jnp.zeros(2), jnp.eye(2), 3.0) returns a float32 array of shape (3, 2), deterministic for seed 0.
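One possible sketch of multivariate_normal_sample, assuming the integer seed is converted with jax.random.PRNGKey(seed). The test set may derive the key differently, in which case the exact values would differ even though the shape and dtype match. The sketch casts n to int and returns float32 explicitly.

```python
import jax
import jax.numpy as jnp


def multivariate_normal_sample(seed, mean, cov, n):
    """Sketch: n draws from N(mean, cov) as a float32 array of shape (n, d).

    Assumes the seed becomes a key via jax.random.PRNGKey; the expected
    key derivation is not stated in the problem.
    """
    key = jax.random.PRNGKey(int(seed))
    n = int(n)  # n may arrive as a float such as 3.0
    samples = jax.random.multivariate_normal(
        key,
        jnp.asarray(mean, dtype=jnp.float32),
        jnp.asarray(cov, dtype=jnp.float32),
        shape=(n,),
    )
    return samples.astype(jnp.float32)


out = multivariate_normal_sample(0, jnp.zeros(2), jnp.eye(2), 3.0)
print(out.shape, out.dtype)  # (3, 2) float32
```

Because the key is derived only from seed, repeated calls with the same seed return identical arrays.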

Hints

jax.random.multivariate_normal
