
Batched Sampling via vmap

Why this matters

A common pattern in JAX is generating independent random samples for an entire batch in a single vectorised call. The recipe is always the same: split the base key into one sub-key per batch element, then vmap the per-key sampler over the resulting array of keys. Compared with a Python loop, this runs as one vectorised computation and composes cleanly with jit. It also avoids the subtle mistake of reusing the same key for every element, which would produce identical samples.
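To make the loop-versus-vmap comparison concrete, here is a minimal sketch, assuming the standard jax.random and jax.vmap APIs; the helper names loop_sample and vmap_sample are hypothetical. Both helpers consume the same sub-keys, so they should return the same values, but the vmapped version can be compiled with jit as a single call. Because batch_size and n determine output shapes, they are marked static for jit.

```python
import jax
import jax.numpy as jnp
from jax import random


def loop_sample(key, batch_size, n):
    # Hypothetical helper: one Python-level sampler call per batch element.
    keys = random.split(key, batch_size)
    return jnp.stack([random.normal(k, (n,)) for k in keys])


def vmap_sample(key, batch_size, n):
    # Hypothetical helper: split once, then vectorise the per-key sampler.
    keys = random.split(key, batch_size)
    return jax.vmap(lambda k: random.normal(k, (n,)))(keys)


# batch_size and n fix the output shape, so they are static arguments under jit.
vmap_sample_jit = jax.jit(vmap_sample, static_argnums=(1, 2))

key = random.PRNGKey(0)
looped = loop_sample(key, 4, 5)          # shape (4, 5)
vectorised = vmap_sample_jit(key, 4, 5)  # shape (4, 5), same values as looped
```

vmap guarantees that row i of the result equals the sampler applied to keys[i], which is why the two versions agree.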

Worked mini-example

batch_size = 2, n = 3, seed = 0.

import jax
from jax import random

base_key = random.PRNGKey(0)
keys = random.split(base_key, 2)                            # shape (2, 2): two independent sub-keys
samples = jax.vmap(lambda k: random.normal(k, (3,)))(keys)  # shape (2, 3)

Each row is an independent draw of 3 Normal(0,1) values.
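A quick way to sanity-check the worked example is sketched below, assuming default JAX settings (64-bit mode off, so samples are float32); the helper name sample is hypothetical. It checks that the same seed reproduces the result exactly and, with a much larger batch, that the empirical mean and standard deviation are close to 0 and 1.

```python
import jax
import jax.numpy as jnp
from jax import random


def sample(seed, batch_size, n):
    # Hypothetical helper mirroring the worked example: split, then vmap.
    keys = random.split(random.PRNGKey(seed), batch_size)
    return jax.vmap(lambda k: random.normal(k, (n,)))(keys)


small = sample(0, 2, 3)      # the worked example: shape (2, 3)
again = sample(0, 2, 3)      # same seed, so the result is bit-identical
big = sample(0, 1000, 100)   # 100,000 draws for a rough moment check

mean, std = float(big.mean()), float(big.std())
```

The tolerances in a check like this should be loose: with 100,000 draws the sample mean has a standard error of roughly 0.003.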

Common pitfalls

  • Splitting inside vmap: vmap(lambda k: normal(split(k)[0], (n,))) works but does redundant work; split once up front so each mapped call receives a unique, ready-to-use key.
  • Shapes must be Python ints: write (int(n),), not (n,), when n arrives as a Python float; JAX rejects non-integer shape entries with a TypeError.
  • Don’t reuse base_key: passing the same key to every call in a loop produces identical (perfectly correlated) samples; always split before batching.
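The second and third pitfalls can be reproduced directly. The sketch below assumes the standard jax.random API; the variable names are illustrative only. It shows that reusing one key yields identical rows, that splitting first yields distinct rows, and that a float entry in the shape tuple is rejected with a TypeError.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
n = 3.0  # a float argument, as in the problem statement

# Pitfall: reusing one key for every element gives identical rows.
reused = jnp.stack([random.normal(key, (int(n),)) for _ in range(2)])

# Correct: split first, then map the sampler over the sub-keys.
keys = random.split(key, 2)
split_rows = jax.vmap(lambda k: random.normal(k, (int(n),)))(keys)

# Pitfall: a float shape entry is rejected before any sampling happens.
try:
    random.normal(key, (n,))
    float_shape_ok = True
except TypeError:
    float_shape_ok = False
```

Note that int(n) only works on a concrete Python value; inside a jit-compiled function, shape arguments like n must be marked static instead.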

Problem

Implement vmap_normal_batch(seed, batch_size, n):

  • seed (float) → jax.random.PRNGKey(int(seed))
  • batch_size (float, cast to int): number of independent draws
  • n (float, cast to int): samples per draw

Return a 2-D float32 array of shape (batch_size, n).
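One possible shape for a solution is sketched below. It is a minimal sketch of the recipe described above, not the official solution, and it assumes default JAX settings so the output dtype is float32. It casts the float arguments to Python ints before using them as shapes, splits once, and vmaps the per-key sampler.

```python
import jax
import jax.numpy as jnp
from jax import random


def vmap_normal_batch(seed, batch_size, n):
    # Cast the float arguments to Python ints before they are used as shapes.
    batch_size, n = int(batch_size), int(n)
    base_key = random.PRNGKey(int(seed))
    # One independent sub-key per batch element: shape (batch_size, 2).
    keys = random.split(base_key, batch_size)
    # Map the per-key sampler over the leading axis of keys.
    return jax.vmap(lambda k: random.normal(k, (n,)))(keys)


out = vmap_normal_batch(0.0, 2.0, 3.0)  # shape (2, 3), dtype float32
```

The sketch stays eager on purpose; wrapping it in jax.jit would require batch_size and n to be static arguments, since they determine the output shape.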

Hints

jax vmap random batched
