
PRNGKey and Split

Why this matters

JAX takes a different approach to randomness than NumPy or PyTorch. Instead of a global PRNG state, every random function takes an EXPLICIT KEY as its first argument. This is the price (and benefit) of pure functions: a function with random behavior is impure if it secretly mutates global state, but PURE if you pass the key in. The key is a small JAX array (typically a uint32[2]).
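A minimal sketch of what "pass the key in" looks like. It assumes the default PRNG implementation (threefry), for which PRNGKey returns a uint32 array of shape (2,). The function noisy_scale is a hypothetical example, not a JAX API: because its only source of randomness is the key argument, calling it twice with the same key gives the same result.

```python
import jax
import jax.numpy as jnp

def noisy_scale(key, x):
    # Hypothetical pure function: randomness comes only from `key`, no hidden global state.
    noise = jax.random.normal(key, x.shape)
    return x + 0.1 * noise

key = jax.random.PRNGKey(0)
print(key.shape, key.dtype)  # (2,) uint32 with the default PRNG implementation

x = jnp.ones(3)
a = noisy_scale(key, x)
b = noisy_scale(key, x)  # same key in, same result out
```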

The jax.random.PRNGKey(seed) function creates a key from a Python int. To get N independent samples, call jax.random.split(key, N) to derive N sub-keys deterministically, then call your random op with each sub-key. NEVER reuse a key for two random ops: that yields identical or correlated samples instead of independent ones.
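The sketch below shows both halves of that paragraph: split(key, N) handing out one sub-key per draw, and the fact that splitting is deterministic. The loop at the end is a common idiom (not the only one) for a sequence of draws: split off a fresh subkey each step and keep the other half as the new carry key. The seed values are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 4)  # four sub-keys stacked into one array
samples = [jax.random.normal(k, (3,)) for k in keys]  # one draw per sub-key

# split is deterministic: the same parent key always yields the same sub-keys
assert jnp.array_equal(jax.random.split(key, 4), keys)

# Sequential draws: never use `key` directly, always a freshly split `subkey`
key = jax.random.PRNGKey(0)
draws = []
for _ in range(3):
    key, subkey = jax.random.split(key)
    draws.append(jax.random.uniform(subkey))
```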

Worked mini-example

import jax, jax.numpy as jnp

key = jax.random.PRNGKey(0)
# โŒ Reusing the key:
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))   # IDENTICAL to x1!

# ✅ Splitting:
k1, k2 = jax.random.split(key, 2)
x1 = jax.random.normal(k1, (3,))
x2 = jax.random.normal(k2, (3,))   # independent
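To confirm the two comments in the example rather than take them on faith, compare the arrays directly with jnp.array_equal. This is the same setup as above with checks added:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reused key: the same op, key and shape give bit-identical output
same_a = jax.random.normal(key, (3,))
same_b = jax.random.normal(key, (3,))
print(jnp.array_equal(same_a, same_b))  # True

# Split key: each sub-key produces its own draw
k1, k2 = jax.random.split(key, 2)
ind_a = jax.random.normal(k1, (3,))
ind_b = jax.random.normal(k2, (3,))
print(jnp.array_equal(ind_a, ind_b))  # False
```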

Common pitfalls

  • Reusing the same key: produces identical outputs, not independent samples. Always split.
  • For keys created with PRNGKey, split(key, N) returns the N keys as one 2-D array of shape (N, 2). Unpack with k1, k2 = split(key, 2) or index keys[i].
  • PRNGKey requires an int: cast Python floats with int(seed) first.
  • Forgetting the shape arg: jax.random.normal(key) returns a SCALAR. Pass (n,) for a vector.
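The shape-related pitfalls in the list above can be checked directly. This sketch assumes the default uint32 keys from PRNGKey; the float seed 7.0 stands in for a value that arrives as a float.

```python
import jax

# split(key, N) stacks N uint32 keys into one (N, 2) array
keys = jax.random.split(jax.random.PRNGKey(0), 3)
print(keys.shape)  # (3, 2)
k0 = keys[0]       # index a single sub-key, shape (2,)
k1, k2 = jax.random.split(k0, 2)  # or unpack the result directly

# PRNGKey needs an integer seed: cast a float first
seed = 7.0
key = jax.random.PRNGKey(int(seed))

# Without a shape argument you get a scalar
print(jax.random.normal(key).shape)        # ()
print(jax.random.normal(key, (5,)).shape)  # (5,)
```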

Problem

Implement two_random_draws(seed, n) that creates a base key from seed, splits it into two sub-keys, draws an independent N(0, 1) sample of length n with each sub-key, and stacks the results into a 2-D array of shape (2, n).

Both seed and n arrive as floats; cast them to int inside the function.

One illustrative example (not from the test set):

  • two_random_draws(1, 3) returns a 2-D array of shape (2, 3): two independently drawn length-3 standard normal vectors, deterministic for seed 1.
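If you want to sanity-check your own implementation before submitting, a small property check helps. The helper check_two_random_draws below is hypothetical (it is not the problem's grader) and only tests the properties stated above: float arguments are accepted, the output has shape (2, n), the same seed reproduces the same array, and the two rows differ, as they should when each comes from its own sub-key.

```python
import jax.numpy as jnp

def check_two_random_draws(fn, seed=1.0, n=3.0):
    """Hypothetical self-check for a two_random_draws implementation (not the site's grader)."""
    out = fn(seed, n)  # both arguments arrive as floats, as the problem states
    assert out.shape == (2, int(n)), f"expected shape (2, {int(n)}), got {out.shape}"
    assert jnp.issubdtype(out.dtype, jnp.floating), "samples should be floating point"
    # Deterministic: the same seed must reproduce the same array
    assert jnp.array_equal(out, fn(seed, n)), "same seed gave different output"
    # Independent sub-keys: the two rows should not be copies of each other
    assert not jnp.array_equal(out[0], out[1]), "rows are identical; was a key reused?"
    return True
```

Call it as check_two_random_draws(two_random_draws); an implementation that reuses one key for both rows trips the last assertion.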

Hints

jax prng split
