Normal Sampling with mean and std
Why this matters
jax.random.normal(key, shape) draws from N(0, 1), the standard normal.
To sample from an arbitrary N(mu, sigma^2), you apply the standard affine
transform: mu + sigma * z where z ~ N(0, 1). This is the reparameterization
trick in its simplest form, and it underlies variational autoencoders, diffusion
models, and any code that needs gradients through stochastic nodes.
Understanding this two-step pattern (draw from the standard normal, then shift and scale) is essential before tackling the full reparameterization trick.
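To show why the shift-and-scale form matters for gradients, here is a minimal sketch that differentiates a sample-based loss with respect to mu and sigma using jax.grad. The function name `loss`, the E[x^2] objective, and the sample size are illustrative assumptions, not part of the exercise. Because the noise z does not depend on the parameters, gradients flow through x = mu + sigma * z as through any ordinary arithmetic.

```python
import jax
import jax.numpy as jnp

def loss(params, key):
    # Hypothetical objective: Monte Carlo estimate of E[x^2] for x ~ N(mu, sigma^2)
    mu, sigma = params
    z = jax.random.normal(key, (100_000,))  # noise does not depend on params
    x = mu + sigma * z                       # reparameterized sample
    return jnp.mean(x ** 2)

key = jax.random.PRNGKey(0)
params = (jnp.array(1.0), jnp.array(2.0))  # mu = 1, sigma = 2
grad_mu, grad_sigma = jax.grad(loss)(params, key)

# Since E[x^2] = mu^2 + sigma^2, the exact gradients are 2*mu = 2 and 2*sigma = 4;
# the Monte Carlo estimates should land close to those values.
print(float(grad_mu), float(grad_sigma))
```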
Worked mini-example
import jax, jax.numpy as jnp
key = jax.random.PRNGKey(0)
# Standard normal N(0, 1)
z = jax.random.normal(key, (4,))
# Shift to N(mu=5, sigma=2)
mu, sigma = 5.0, 2.0
x = mu + sigma * z
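As a quick sanity check on the worked example, drawing many samples should give an empirical mean and standard deviation near mu and sigma. The sample size of 100,000 is an arbitrary choice for illustration; the estimates will not match exactly because the sample is finite.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu, sigma = 5.0, 2.0

z = jax.random.normal(key, (100_000,))  # 100,000 draws from N(0, 1)
x = mu + sigma * z                       # shifted and scaled to N(5, 2^2)

# Finite-sample estimates: close to mu and sigma, but not exact
mean_x = float(jnp.mean(x))
std_x = float(jnp.std(x))
print(mean_x, std_x)
```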
Common pitfalls
- jax.random.normal has no mu or sigma arguments, so you must apply the affine transform manually.
- Order matters: mu + sigma * z, not sigma * (z + mu). The latter has mean sigma * mu instead of mu.
- Shape must be a tuple: (int(n),), not bare int(n).
- Cast seed to int: jax.random.PRNGKey(int(seed)).
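The sketch below makes the pitfalls concrete: it casts float inputs to int, passes the shape as a tuple, and compares the correct transform with the mistaken sigma * (z + mu). The specific values of seed, n, mu, and sigma are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

seed, n = 0.0, 100_000.0               # floats, as in the problem statement
key = jax.random.PRNGKey(int(seed))    # PRNGKey needs an integer seed
z = jax.random.normal(key, (int(n),))  # shape passed as a tuple of ints

mu, sigma = 5.0, 2.0
right = mu + sigma * z    # N(mu, sigma^2): mean near 5
wrong = sigma * (z + mu)  # N(sigma * mu, sigma^2): mean near 10
print(float(jnp.mean(right)), float(jnp.mean(wrong)))
```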
Problem
Implement normal_mu_sigma(seed, n, mu, sigma) that draws n samples from
N(mu, sigma^2) by first drawing from N(0, 1) and applying the affine
transform.
Both seed and n arrive as floats; cast them to int inside the function.
One illustrative example (not from the test set):
- normal_mu_sigma(0, 4, 0.0, 1.0) returns a 1-D array of shape (4,): four samples from N(0, 1), deterministic for seed 0.
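One possible sketch of normal_mu_sigma that follows the problem statement directly. It is not the site's reference solution, and it assumes the tests accept JAX's default float32 output.

```python
import jax
import jax.numpy as jnp

def normal_mu_sigma(seed, n, mu, sigma):
    """Draw n samples from N(mu, sigma^2) via the affine transform of N(0, 1)."""
    key = jax.random.PRNGKey(int(seed))    # seed may arrive as a float
    z = jax.random.normal(key, (int(n),))  # standard normal, shape (n,)
    return mu + sigma * z                  # shift by mu, scale by sigma

samples = normal_mu_sigma(0, 4, 0.0, 1.0)
print(samples.shape)  # (4,)
```

Because the key is built from the seed alone, calling the function again with the same seed (even passed as 0.0) reproduces the same draws.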