
Deterministic Batch via vmap+split

Why this matters

When you need a batch of N independent random samples (e.g., a dropout mask per layer, a per-example noise tensor, an ensemble of stochastic policies), the canonical JAX pattern is vmap over a split key array:

  1. keys = jax.random.split(base, N) → an array of N keys.
  2. vmap(draw)(keys) → draw(keys[i]) for each i, evaluated by XLA as a single vectorized computation rather than N separate calls.
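The two steps above promise that row i of the vmapped result equals draw(keys[i]). Here is a minimal check of that, assuming the legacy jax.random.PRNGKey key type. The seed 0, N = 4, and the length-3 draw are illustrative choices only.

```python
import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)
N = 4

def draw(k):
    # One sample per key; the shape (3,) is an illustrative choice.
    return jax.random.normal(k, (3,))

keys = jax.random.split(base, N)   # step 1: N sub-keys, shape (N, 2)
batch = jax.vmap(draw)(keys)       # step 2: one batched call over axis 0

# Row i of the batched result matches calling draw on keys[i] directly.
for i in range(N):
    assert jnp.allclose(batch[i], draw(keys[i]))
```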

This pattern is more efficient than a Python for-loop, which jit unrolls into N separate copies of the draw at trace time. It also keeps each per-sample computation reproducible (same key → same output). It’s the “batch” sibling of the per-step fold_in pattern from the previous problem.
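The paragraph above makes two claims: the vmapped draw is reproducible, and it computes the same values as a loop while tracing only one copy of draw. Both can be checked with a short sketch. The function names batched and looped are hypothetical, and the keys use the legacy jax.random.PRNGKey type.

```python
import jax
import jax.numpy as jnp

def draw(k):
    return jax.random.normal(k, (3,))

@jax.jit
def batched(keys):
    # vmap traces draw once and batches it over the leading axis of keys.
    return jax.vmap(draw)(keys)

@jax.jit
def looped(keys):
    # A Python loop is unrolled at trace time: one copy of draw per key.
    return jnp.stack([draw(keys[i]) for i in range(keys.shape[0])])

keys = jax.random.split(jax.random.PRNGKey(0), 4)

a = batched(keys)
b = batched(keys)   # same keys in, same numbers out
c = looped(keys)    # same values, but a larger traced program

print(jnp.array_equal(a, b))   # True
print(jnp.allclose(a, c))      # True
```

To see the difference in program size, one option is to compare jax.make_jaxpr(batched)(keys) with jax.make_jaxpr(looped)(keys). The looped version repeats the sampling ops once per key.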

Worked mini-example

import jax, jax.numpy as jnp

base = jax.random.PRNGKey(0)
batch_size = 4

keys = jax.random.split(base, batch_size)   # shape (4, 2)
def draw(k):
    return jax.random.normal(k, (3,))       # one length-3 sample per key

batch = jax.vmap(draw)(keys)
# batch.shape == (4, 3) — each row is an independent N(0,1) sample

Common pitfalls

  • Calling jax.random.normal(base, (batch_size, n)) gives ONE draw of that shape from a single key, NOT a batch of per-key samples. Its entries are still i.i.d. N(0, 1), but no row is tied to its own sub-key. The values therefore differ from the vmap+split batch, and you cannot regenerate row i on its own.
  • Forgetting to vmap: a Python list comprehension over keys produces the same values, but inside jit it is unrolled into one copy of draw per key.
  • Wrong n type: the shape in jax.random.normal(k, (n,)) must contain Python ints, so a float such as 4.0 is rejected. Cast with int(n) outside the vmapped draw; the closure then captures the integer.
  • Splitting WHILE vmapping: don’t call jax.random.split inside the vmapped function. Split first, then vmap over the resulting keys.
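The pitfalls above can be reproduced side by side in one sketch. It assumes the legacy jax.random.PRNGKey key type. The seed, batch_size = 4, and n = 3.0 (a float, as in the problem below) are illustrative.

```python
import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)
batch_size, n = 4, 3.0   # n arrives as a float

n = int(n)   # pitfall 3: shapes need Python ints

keys = jax.random.split(base, batch_size)   # pitfall 4: split BEFORE vmap

def draw(k):
    return jax.random.normal(k, (n,))   # the closure captures the int n

per_key = jax.vmap(draw)(keys)                        # one row per sub-key
one_draw = jax.random.normal(base, (batch_size, n))   # pitfall 1: single key

print(per_key.shape == one_draw.shape)           # True: same shape...
print(jnp.allclose(per_key, one_draw))           # False: ...different numbers
print(jnp.allclose(per_key[1], draw(keys[1])))   # True: row 1 regenerable alone
```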

Problem

Implement deterministic_batch(seed, batch_size, n):

  • Create a base key from seed using jax.random.PRNGKey(int(seed)).
  • Split it into batch_size sub-keys with jax.random.split.
  • Define a per-sample draw function that draws n values from N(0, 1).
  • Apply the draw function over all sub-keys using jax.vmap.
  • Return a 2-D array of shape (batch_size, n).

All three arguments arrive as floats; cast them to int inside the function.
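The steps above translate almost line for line into JAX. Below is one possible sketch, not the site's reference solution, following the spec's use of the legacy jax.random.PRNGKey.

```python
import jax

def deterministic_batch(seed, batch_size, n):
    # All three arguments arrive as floats; seeds and shapes need Python ints.
    seed, batch_size, n = int(seed), int(batch_size), int(n)

    base = jax.random.PRNGKey(seed)             # base key from the seed
    keys = jax.random.split(base, batch_size)   # batch_size sub-keys

    def draw(k):
        # n values from N(0, 1) for a single sub-key
        return jax.random.normal(k, (n,))

    return jax.vmap(draw)(keys)                 # shape (batch_size, n)
```

With this sketch, deterministic_batch(1.0, 3.0, 4.0) has shape (3, 4), and calling it twice with the same arguments returns identical arrays, because every row depends only on its split-derived key.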

Two illustrative examples (not from the test set):

  • deterministic_batch(1, 3, 4) returns a (3, 4) array — three rows of four independent standard-normal values, one row per batch element.
  • deterministic_batch(99, 5, 1) returns a (5, 1) array — five scalars, each drawn from N(0, 1) via its own split-derived key.

Hints

jax prng vmap split
