
Normal Sampling with mean and std

Why this matters

jax.random.normal(key, shape) draws from N(0, 1), the standard normal. To sample from an arbitrary N(mu, sigma^2), apply the affine transform mu + sigma * z, where z ~ N(0, 1). This is the reparameterization trick in its simplest form, and it underlies variational autoencoders, diffusion models, and any code that needs gradients through stochastic nodes.

Understanding this two-step pattern (draw from the standard normal, then scale and shift) is essential before tackling the full reparameterization trick.
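To see why the transform lets gradients pass through a random draw, here is a minimal sketch. It assumes a toy objective that is just the mean of the samples, and sample_mean is a hypothetical helper name, not part of JAX. Because z comes from a fixed key and does not depend on mu or sigma, jax.grad only has to differentiate the affine step: the gradient with respect to mu is 1, and the gradient with respect to sigma is the mean of z.

```python
import jax
import jax.numpy as jnp

def sample_mean(params, key):
    # Hypothetical toy objective: the mean of 1000 samples from N(mu, sigma^2).
    mu, sigma = params
    # z does not depend on params, so no gradient flows into the sampler itself.
    z = jax.random.normal(key, (1000,))
    x = mu + sigma * z  # differentiable in mu and sigma
    return jnp.mean(x)

key = jax.random.PRNGKey(0)
params = (5.0, 2.0)

# jax.grad returns gradients with the same tuple structure as params.
grad_mu, grad_sigma = jax.grad(sample_mean)(params, key)

# Same key and shape, so this reproduces the z used inside sample_mean.
z = jax.random.normal(key, (1000,))
print(float(grad_mu), float(grad_sigma), float(jnp.mean(z)))
```

d/dmu of mean(mu + sigma * z) is 1 and d/dsigma is mean(z), so the last two printed numbers should agree up to float32 rounding.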

Worked mini-example

import jax, jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Standard normal N(0, 1)
z = jax.random.normal(key, (4,))

# Scale by sigma, then shift by mu: x ~ N(5, 2^2)
mu, sigma = 5.0, 2.0
x = mu + sigma * z
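A quick sanity check on the worked example: with a larger draw, the sample mean and standard deviation of x should land close to mu and sigma. This is a sketch; the sample size of 100,000 is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu, sigma = 5.0, 2.0

# Draw many standard-normal samples, then scale and shift them.
z = jax.random.normal(key, (100_000,))
x = mu + sigma * z

# Sample statistics should be close to the target parameters.
print(float(jnp.mean(x)), float(jnp.std(x)))
```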

Common pitfalls

  • jax.random.normal has no mu or sigma arguments; you must apply the affine transform yourself.
  • Order matters: mu + sigma * z, not sigma * (z + mu). The second form has mean sigma * mu instead of mu.
  • Shape must be a tuple of ints: (int(n),), not a bare int(n).
  • Cast seed to int: jax.random.PRNGKey(int(seed)).
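The order pitfall and the two casts can be checked directly. This sketch uses seed 1 and 100,000 samples purely for illustration; the float-valued seed and n mimic the inputs described in the problem.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
mu, sigma = 5.0, 2.0
z = jax.random.normal(key, (100_000,))

right = mu + sigma * z    # N(mu, sigma^2): mean close to 5
wrong = sigma * (z + mu)  # N(sigma * mu, sigma^2): mean close to 10

print(float(jnp.mean(right)), float(jnp.mean(wrong)))

# Seed and n arriving as floats: cast both before use.
seed, n = 0.0, 4.0
samples = jax.random.normal(jax.random.PRNGKey(int(seed)), (int(n),))
print(samples.shape)
```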

Problem

Implement normal_mu_sigma(seed, n, mu, sigma) that draws n samples from N(mu, sigma^2) by first drawing from N(0, 1) and applying the affine transform.

Both seed and n arrive as floats; cast them to int inside the function.

One illustrative example (not from the test set):

  • normal_mu_sigma(0, 4, 0.0, 1.0) returns a 1-D array of shape (4,): four samples from N(0, 1), deterministic for seed 0.
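One possible implementation, sketched from the two-step recipe above. It is not the site's reference solution, and the grader's expected values may depend on details (such as which PRNG key constructor is used) that are assumed here.

```python
import jax
import jax.numpy as jnp

def normal_mu_sigma(seed, n, mu, sigma):
    # seed and n may arrive as floats; PRNGKey and shapes need Python ints.
    key = jax.random.PRNGKey(int(seed))
    # Step 1: draw n samples from the standard normal N(0, 1).
    z = jax.random.normal(key, (int(n),))
    # Step 2: scale by sigma and shift by mu to get N(mu, sigma^2).
    return mu + sigma * z

out = normal_mu_sigma(0.0, 4.0, 0.0, 1.0)
print(out.shape)  # (4,)
```

With mu = 0 and sigma = 1 the transform is the identity, so the result equals the raw standard-normal draw for the same seed.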

Hints

Use jax.random.normal for the N(0, 1) draw, then apply mu + sigma * z.
