PRNGKey and Split
Why this matters
JAX takes a different approach to randomness than NumPy or PyTorch. Instead
of a global PRNG state, every random function takes an EXPLICIT KEY as its
first argument. This is the price (and benefit) of pure functions: a function
with random behavior is impure if it secretly mutates global state, but PURE
if you pass the key in. The key is a small JAX array (typically a uint32[2]).
The PRNGKey(seed) function creates a key from a Python int. To get N
independent samples, call split(key, N) to derive N sub-keys
deterministically, then call your random op with each sub-key. NEVER reuse
a key for two random ops: that produces correlated (often identical) samples.
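The claims above (a key is a small uint32 array, split is deterministic, and a function that receives its key is pure) can be checked directly. A minimal sketch, assuming JAX's default threefry PRNG implementation; add_noise is a hypothetical helper used only for illustration.

```python
import jax
import jax.numpy as jnp

# A legacy raw key. With JAX's default threefry PRNG it is a uint32 array of shape (2,).
key = jax.random.PRNGKey(42)
print(key.shape, key.dtype)  # (2,) uint32 under the default PRNG

# split is a pure function of its input: the same key always yields the same sub-keys.
subkeys_a = jax.random.split(key, 3)
subkeys_b = jax.random.split(key, 3)
print(jnp.array_equal(subkeys_a, subkeys_b))  # True

# Hypothetical helper: a "random" function that receives its key explicitly is pure,
# so identical arguments give identical outputs.
def add_noise(x, key):
    return x + jax.random.normal(key, x.shape)

x = jnp.zeros(4)
y1 = add_noise(x, subkeys_a[0])
y2 = add_noise(x, subkeys_a[0])
print(jnp.array_equal(y1, y2))  # True
```

Because nothing is mutated behind the scenes, rerunning this script always reproduces the same keys and the same noise.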
Worked mini-example
import jax, jax.numpy as jnp
key = jax.random.PRNGKey(0)
# Wrong: reusing the key
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,)) # IDENTICAL to x1!
# Right: splitting
k1, k2 = jax.random.split(key, 2)
x1 = jax.random.normal(k1, (3,))
x2 = jax.random.normal(k2, (3,)) # independent
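The mini-example splits once, up front. When the number of draws is not fixed in advance (for example inside a loop), a common JAX idiom is to carry a key forward and split off a fresh sub-key at every step. A minimal sketch; draw_sequence is a hypothetical helper name, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

def draw_sequence(seed, steps, n):
    """Draw `steps` length-n N(0, 1) vectors, splitting the key before each draw."""
    key = jax.random.PRNGKey(seed)
    draws = []
    for _ in range(steps):
        # split(key) with no count returns two keys: one to carry forward,
        # one to consume right now. The consumed subkey is never reused.
        key, subkey = jax.random.split(key)
        draws.append(jax.random.normal(subkey, (n,)))
    return jnp.stack(draws)  # shape (steps, n)

samples = draw_sequence(0, 3, 4)
print(samples.shape)  # (3, 4)
```

The carried `key` is only ever used for further splits, so no key feeds two random ops.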
Common pitfalls
- Reusing the same key: produces identical outputs, not independent samples. Always split.
- split(key, N) returns N keys (a 2-D array of shape (N, 2)). Unpack with k1, k2 = split(key, 2) or index keys[i].
- PRNGKey requires an int: cast Python floats with int(seed) first.
- Forgetting the shape arg: jax.random.normal(key) returns a SCALAR. Pass (n,) for a vector.
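The last three pitfalls can be seen in a few lines. A small sketch, assuming the default PRNG implementation (which is what gives split its (N, 2) result); the float seed 7.0 is just an illustrative value.

```python
import jax
import jax.numpy as jnp

seed = 7.0  # illustrative: a seed that arrived as a float

# PRNGKey expects an integer seed, so cast first.
key = jax.random.PRNGKey(int(seed))

# split(key, N) returns an (N, 2) array of raw keys.
keys = jax.random.split(key, 3)
print(keys.shape)  # (3, 2)
k0, k1, k2 = keys  # unpack along the first axis...
assert bool(jnp.array_equal(k1, keys[1]))  # ...or index keys[i]

# Without a shape argument, normal returns a scalar.
print(jax.random.normal(k0).shape)        # ()
print(jax.random.normal(k0, (5,)).shape)  # (5,)
```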
Problem
Implement two_random_draws(seed, n) that creates a base key from seed,
splits it into two sub-keys, draws an independent N(0, 1) sample of length
n with each sub-key, and stacks the results into a 2-D array of shape
(2, n).
Both seed and n arrive as floats; cast them to int inside the function.
One illustrative example (not from the test set):
- two_random_draws(1, 3) returns a 2-D array of shape (2, 3): two independently drawn length-3 standard normal vectors, deterministic for seed 1.
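One possible sketch of two_random_draws, following the steps in the problem statement. It is not the platform's reference solution, and the exact values the hidden tests expect are not shown here; the usage below only checks shape, determinism, and that the two rows differ.

```python
import jax
import jax.numpy as jnp

def two_random_draws(seed, n):
    # seed and n arrive as floats; PRNGKey and shapes need Python ints.
    seed, n = int(seed), int(n)
    key = jax.random.PRNGKey(seed)
    # Derive two independent sub-keys from the base key.
    k1, k2 = jax.random.split(key, 2)
    x1 = jax.random.normal(k1, (n,))
    x2 = jax.random.normal(k2, (n,))
    # Stack the two length-n vectors into shape (2, n).
    return jnp.stack([x1, x2])

out = two_random_draws(1.0, 3.0)
print(out.shape)  # (2, 3)
```

jnp.stack adds a new leading axis, which is what turns two (n,) vectors into a single (2, n) array.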