PRNGKey fold_in
Why this matters
jax.random.fold_in(key, data) deterministically derives a sub-key from a
base key and a piece of integer “data” (typically a step number or a layer
index). Compared to split (which makes N keys all at once), fold_in is
convenient when:
- You don’t know N upfront (you’d need to call split for each new step).
- You want a stable per-step key independent of how many steps there are in total: the key for step 7 is the same whether you ran 10 or 1000 steps.
Common pattern: training loops use fold_in(base_key, step) for per-step
randomness (data shuffling, dropout masks, sample noise). Same step → same
noise → reproducible debugging.
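To make the per-step stability concrete, here is a minimal sketch. The helper name per_step_noise and the noise shape are assumptions made for illustration; only jax.random.PRNGKey, jax.random.fold_in and jax.random.normal come from the text above.

```python
import jax
import jax.numpy as jnp


def per_step_noise(base_key, n_steps, shape=(3,)):
    """Hypothetical helper: one noise array per step, each from its own fold_in key."""
    noises = []
    for step in range(n_steps):
        # The step key depends only on (base_key, step), not on n_steps.
        step_key = jax.random.fold_in(base_key, step)
        noises.append(jax.random.normal(step_key, shape))
    return noises


base_key = jax.random.PRNGKey(0)
short_run = per_step_noise(base_key, 10)
long_run = per_step_noise(base_key, 20)

# Step 7 sees identical noise in both runs, regardless of the run length.
same_at_step_7 = bool(jnp.array_equal(short_run[7], long_run[7]))
```

Re-running only step 7 while debugging should therefore reproduce the same noise without replaying the earlier steps.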
Worked mini-example
import jax, jax.numpy as jnp
base_key = jax.random.PRNGKey(0)
# Sub-key for step 7:
k7 = jax.random.fold_in(base_key, 7)
noise_at_step_7 = jax.random.normal(k7, (3,))
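# The SAME key is recoverable later. Compare with jnp.array_equal, because ==
# on a raw uint32 key compares element-wise and returns an array, not one bool:
jnp.array_equal(jax.random.fold_in(base_key, 7), k7)
# → True (deterministic)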
# split for parallel allocation:
k1, k2 = jax.random.split(base_key, 2) # one shot, two keys
Common pitfalls
- fold_in(key, data) requires integer data. Use int(k) if your loop index is float-typed.
- Don’t call fold_in with the same data and expect different sub-keys: the output is deterministic in (key, data).
- fold_in produces ONE sub-key; for N independent sub-keys at once, use split.
- Inside jit-compiled code, prefer jax.lax.scan with the carry-key pattern over a Python for-loop of fold_in calls (the Python loop is unrolled at trace time).
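The carry-key pattern from the last pitfall can be sketched as follows. This is one possible reading of that pattern, not a canonical implementation: the names scan_noise and body, the noise shape and the step count are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def scan_noise(key, n_steps):
    """Hypothetical example: draw one noise vector per step inside lax.scan."""

    def body(carry_key, _):
        # Split the carried key: one half continues as the carry,
        # the other is consumed for this step's randomness.
        carry_key, sub_key = jax.random.split(carry_key)
        noise = jax.random.normal(sub_key, (3,))
        return carry_key, noise

    # xs=None with an explicit length runs body n_steps times.
    _, noises = jax.lax.scan(body, key, xs=None, length=n_steps)
    return noises


# n_steps sets the scan length, so it is marked static for jit.
scan_noise_jit = jax.jit(scan_noise, static_argnames="n_steps")
noises = scan_noise_jit(jax.random.PRNGKey(0), n_steps=5)
```

Here the loop body is traced once rather than once per step. The trade-off is that the key for step k now depends on all earlier splits, so it is no longer recoverable from (base_key, k) alone the way a fold_in key is.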
Problem
Implement fold_in_per_step(seed, n_steps, n):
- Create a base key from seed using jax.random.PRNGKey(int(seed)).
- Loop k from 0 to n_steps - 1.
- At each step, derive step_key = jax.random.fold_in(base_key, k).
- Draw n samples from N(0, 1) using step_key.
- Stack all per-step samples into a 2-D array of shape (n_steps, n).
All three arguments arrive as floats; cast them to int inside the function.
Two illustrative examples (not from the test set):
- fold_in_per_step(1, 3, 4) returns a (3, 4) array: three rows of four independent standard-normal values, one row per step.
- fold_in_per_step(99, 5, 1) returns a (5, 1) array: five scalars, one per step, each drawn from N(0, 1) via its own fold_in-derived key.
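One possible sketch of the steps listed above, offered as an illustration rather than the reference solution. The handling of n_steps = 0 (returning an empty (0, n) array) is an assumption, since the problem statement does not cover that case.

```python
import jax
import jax.numpy as jnp


def fold_in_per_step(seed, n_steps, n):
    # All three arguments arrive as floats, so cast them first.
    seed, n_steps, n = int(seed), int(n_steps), int(n)
    base_key = jax.random.PRNGKey(seed)

    rows = []
    for k in range(n_steps):
        # One deterministic sub-key per step index.
        step_key = jax.random.fold_in(base_key, k)
        rows.append(jax.random.normal(step_key, (n,)))

    if not rows:
        # Assumption: an empty run yields an empty (0, n) array.
        return jnp.zeros((0, n))
    return jnp.stack(rows)  # shape (n_steps, n)


samples = fold_in_per_step(1.0, 3.0, 4.0)
```

Each row can be checked independently: row k should equal a fresh draw from jax.random.fold_in(jax.random.PRNGKey(seed), k).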