Batched Sampling via vmap
Why this matters
A common pattern in JAX is generating independent random samples for an entire batch in a single vectorised call. The recipe is always the same: split the base key into one sub-key per batch element, then vmap the per-key sampler over the resulting keys array. This is more composable and JIT-friendly than a Python loop and avoids the subtle mistake of reusing the same key for every element (which would produce identical samples).
Worked mini-example
batch_size = 2, n = 3, seed = 0.
base_key = PRNGKey(0)
keys = split(base_key, 2) # shape (2, 2): two independent sub-keys
vmap(lambda k: normal(k, (3,)))(keys) # shape (2, 3)
Each row is an independent draw of 3 Normal(0,1) values.
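The shorthand above can be spelled out as a runnable sketch. It assumes the default legacy uint32 keys returned by jax.random.PRNGKey, which is why the keys array has a trailing dimension of 2; the variable name samples is illustrative only.

```python
import jax
import jax.numpy as jnp

# One base key, split into one sub-key per batch element.
base_key = jax.random.PRNGKey(0)
keys = jax.random.split(base_key, 2)   # shape (2, 2) with legacy uint32 keys

# Map the per-key sampler over the leading axis of `keys`.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)

print(keys.shape)     # (2, 2)
print(samples.shape)  # (2, 3)
```

Because each row is drawn from its own sub-key, the two rows differ from each other.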
Common pitfalls
- Splitting inside vmap: vmap(lambda k: normal(split(k)[0], (n,))) works when the mapped keys are already distinct, but the extra split is wasteful; split once up front so each mapped call receives a unique, already-split key.
- Shape must be a Python int: use (int(n),), not (n,), when n arrives as a Python float; JAX rejects a non-integer shape with an error.
- Don't reuse base_key: passing the same key to every call in a loop produces identical samples for every element; always split before batching.
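The key-reuse pitfall above is easy to check directly. The following is a small sketch, not a definitive test: the names repeated, bad, and good are made up for illustration, and stacking raw keys with jnp.stack assumes the legacy uint32 key format.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)
batch_size, n = 2, 3.0   # n arrives as a float, so it is cast with int(n) below

# Pitfall: the same key for every batch element yields identical rows.
repeated = jnp.stack([base_key] * batch_size)
bad = jax.vmap(lambda k: jax.random.normal(k, (int(n),)))(repeated)

# Fix: split first, then map over the distinct sub-keys.
keys = jax.random.split(base_key, batch_size)
good = jax.vmap(lambda k: jax.random.normal(k, (int(n),)))(keys)

rows_identical_bad = bool(jnp.array_equal(bad[0], bad[1]))
rows_identical_good = bool(jnp.array_equal(good[0], good[1]))
print(rows_identical_bad, rows_identical_good)
```

Comparing the rows of each result shows whether the batch elements are actually independent draws.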
Problem
Implement vmap_normal_batch(seed, batch_size, n):
- seed (float): passed as jax.random.PRNGKey(int(seed))
- batch_size (float, cast to int): number of independent draws
- n (float, cast to int): samples per draw
Return a 2-D float32 array of shape (batch_size, n).
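One possible way to meet this specification is sketched below. It is not the site's reference solution; it simply applies the split-then-vmap recipe, and the helper name sample_one is an assumption introduced for readability.

```python
import jax
import jax.numpy as jnp

def vmap_normal_batch(seed, batch_size, n):
    # Arguments arrive as floats; shapes and split counts need Python ints.
    batch_size, n = int(batch_size), int(n)
    base_key = jax.random.PRNGKey(int(seed))

    # One independent sub-key per batch element.
    keys = jax.random.split(base_key, batch_size)

    # Per-key sampler (hypothetical helper), mapped over the keys array.
    def sample_one(key):
        return jax.random.normal(key, (n,), dtype=jnp.float32)

    return jax.vmap(sample_one)(keys)

out = vmap_normal_batch(0.0, 2.0, 3.0)
print(out.shape, out.dtype)  # (2, 3) float32
```

The int() casts need concrete Python values, so this sketch would have to treat batch_size and n as static arguments if the whole function were wrapped in jax.jit.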