Deterministic Batch via vmap+split
Why this matters
When you need a batch of N independent random samples (e.g., a dropout mask
per layer, a per-example noise tensor, an ensemble of stochastic policies),
the canonical JAX pattern is vmap over a split key array:
- keys = jax.random.split(base, N) → an array of N keys.
- vmap(draw)(keys) → draw(keys[i]) for each i, computed as one vectorized (batched) operation under XLA rather than N separate calls.
This pattern is more efficient than a Python for-loop (which forces unrolling
under jit) AND keeps each per-sample computation reproducible (same key →
same output). It’s the “batch” sibling of the per-step fold_in pattern from
the previous problem.
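Below is a small check of both claims, the loop equivalence and the reproducibility. It assumes the legacy `jax.random.PRNGKey` key API used throughout this page. It compares a jitted vmap against an explicit Python loop over the same keys, then re-runs the batched version with the same keys.

```python
import jax
import jax.numpy as jnp

def draw(k):
    # One length-3 standard-normal sample per key.
    return jax.random.normal(k, (3,))

base = jax.random.PRNGKey(0)
keys = jax.random.split(base, 4)  # shape (4, 2): four legacy uint32 keys

# Batched: one vectorized call over all keys, compiled once by jit.
batched = jax.jit(jax.vmap(draw))(keys)

# Reference: an explicit Python loop, one call per key (unrolls if placed under jit).
looped = jnp.stack([draw(k) for k in keys])

# Reproducibility: the same keys give the same batch on a second run.
again = jax.jit(jax.vmap(draw))(keys)

print(batched.shape)  # (4, 3)
print(bool(jnp.allclose(batched, looped, atol=1e-6)))
print(bool(jnp.array_equal(batched, again)))
```

Row i of `batched` depends only on `keys[i]`. That is why it matches the loop and why it can be reproduced from the key alone.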
Worked mini-example
import jax, jax.numpy as jnp

base = jax.random.PRNGKey(0)
batch_size = 4
keys = jax.random.split(base, batch_size)  # shape (4, 2)

def draw(k):
    return jax.random.normal(k, (3,))  # one length-3 sample per key

batch = jax.vmap(draw)(keys)
# batch.shape == (4, 3): each row is an independent length-3 N(0, 1) sample
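In the per-example-noise use case from the opening paragraph, the vmapped function usually takes the example as well as the key. The sketch below shows one way to handle that with `in_axes`. It maps over keys and examples together and broadcasts a shared scalar. The `add_noise` helper, the all-zeros `data` batch, and the 0.1 scale are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

def add_noise(k, x, scale):
    # Hypothetical per-example augmentation: x plus scale * N(0, 1) noise shaped like x.
    return x + scale * jax.random.normal(k, x.shape)

data = jnp.zeros((4, 3))  # hypothetical batch: 4 examples, 3 features each
keys = jax.random.split(jax.random.PRNGKey(0), data.shape[0])

# in_axes: map over keys and examples along axis 0; pass `scale` unbatched (None).
noisy = jax.vmap(add_noise, in_axes=(0, 0, None))(keys, data, 0.1)

print(noisy.shape)  # (4, 3)
```

Each example receives noise from its own key, so `noisy[i]` equals `add_noise(keys[i], data[i], 0.1)`.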
Common pitfalls
- Calling jax.random.normal(base, (batch_size, n)) gives ONE draw of that shape from ONE key, NOT a batch of per-example samples. The values are still valid standard normals, but row i is not tied to its own key. It will not match draw(keys[i]) and cannot be reproduced from a per-example key.
- Forgetting to vmap: a Python list comprehension over keys works, but it unrolls inside jit.
- Wrong n cast: jax.random.normal(k, (n,)) requires n to be a Python int. Cast with int(n) outside the vmap’d draw if needed (the closure captures it).
- Splitting WHILE vmapping: don’t put jax.random.split inside the vmapped function. Split first, then vmap.
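The first and third pitfalls can be seen side by side in the sketch below. It assumes `n` arrives as a float and is cast once outside the mapped function. The single-key draw has the right shape, but its values differ from the split-then-vmap batch, whose rows each come from their own sub-key.

```python
import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)
batch_size, n_float = 4, 3.0  # n arriving as a float
n = int(n_float)              # cast once, outside the vmapped function

# Pitfall: one key, one draw of shape (batch_size, n).
single = jax.random.normal(base, (batch_size, n))

# Split first, then vmap: row i is generated from keys[i] alone.
keys = jax.random.split(base, batch_size)
batched = jax.vmap(lambda k: jax.random.normal(k, (n,)))(keys)

print(single.shape == batched.shape)  # True
print(bool(jnp.array_equal(single, batched)))
```

The shapes agree, but the two arrays are not interchangeable. Only `batched[i]` can be regenerated from `keys[i]`.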
Problem
Implement deterministic_batch(seed, batch_size, n):
- Create a base key from seed using jax.random.PRNGKey(int(seed)).
- Split it into batch_size sub-keys with jax.random.split.
- Define a per-sample draw function that draws n values from N(0, 1).
- Apply the draw function over all sub-keys using jax.vmap.
- Return a 2-D array of shape (batch_size, n).
All three arguments arrive as floats; cast them to int inside the function.
Two illustrative examples (not from the test set):
- deterministic_batch(1, 3, 4) returns a (3, 4) array: three rows of four independent standard-normal values, one row per batch element.
- deterministic_batch(99, 5, 1) returns a (5, 1) array: five scalars, each drawn from N(0, 1) via its own split-derived key.
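One possible sketch that follows the five steps above is shown below. It is not the site's reference solution. It assumes the legacy `jax.random.PRNGKey` API named in the problem statement and casts all three float arguments to `int` up front.

```python
import jax
import jax.numpy as jnp

def deterministic_batch(seed, batch_size, n):
    # Arguments arrive as floats; seeds and shapes must be Python ints.
    seed, batch_size, n = int(seed), int(batch_size), int(n)

    # Step 1: base key from the seed.
    base = jax.random.PRNGKey(seed)
    # Step 2: one sub-key per batch element, shape (batch_size, 2).
    keys = jax.random.split(base, batch_size)

    # Step 3: per-sample draw; the closure captures n as a Python int.
    def draw(k):
        return jax.random.normal(k, (n,))

    # Steps 4-5: map the draw over all sub-keys -> shape (batch_size, n).
    return jax.vmap(draw)(keys)

a = deterministic_batch(1.0, 3.0, 4.0)
b = deterministic_batch(99.0, 5.0, 1.0)
print(a.shape, b.shape)  # (3, 4) (5, 1)
```

Calling the function twice with the same arguments returns identical arrays. Row i equals a direct `jax.random.normal` call on the i-th split key.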