
PRNGKey fold_in

Why this matters

jax.random.fold_in(key, data) deterministically derives a sub-key from a base key and a piece of integer “data” (typically a step number or a layer index). Compared to split (which makes N keys all at once), fold_in is convenient when:

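  • You don’t know N upfront. With split you would have to pick N in advance or keep re-splitting a running key, which ties each step’s key to the whole chain of earlier splits.
  • You want a stable per-step key that is independent of the total number of steps: the key for step 7 is the same whether you run 10 or 1000 steps.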
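
To check the second point directly, here is a minimal sketch. keys_for_run is a hypothetical helper that derives every step's key with fold_in from a fixed seed (legacy jax.random.PRNGKey uint32 keys assumed). Step 7's key comes out identical whether the run has 10 or 1000 steps, because it depends only on the seed and the index 7.

```python
import jax
import jax.numpy as jnp

def keys_for_run(seed, total_steps):
    """Hypothetical helper: one fold_in-derived key per step of a run."""
    base_key = jax.random.PRNGKey(seed)
    # Step k's key depends only on (seed, k), never on total_steps.
    return [jax.random.fold_in(base_key, k) for k in range(total_steps)]

short_run = keys_for_run(0, 10)
long_run = keys_for_run(0, 1000)

# Step 7's key is identical in both runs.
print(jnp.array_equal(short_run[7], long_run[7]))  # True
```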

Common pattern: training loops use fold_in(base_key, step) for per-step randomness (data shuffling, dropout masks, sample noise). Same step → same noise → reproducible debugging.
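
A sketch of that training-loop pattern, using a hypothetical dropout_mask helper. The name, the keep_prob default of 0.9, and the nested step-then-layer folding are illustrative choices, not a prescribed API. Folding in the step and then a layer index gives each (step, layer) pair its own key, and replaying a step regenerates exactly the same mask.

```python
import jax
import jax.numpy as jnp

def dropout_mask(base_key, step, layer, shape, keep_prob=0.9):
    """Hypothetical helper: a dropout keep-mask for one (step, layer) pair."""
    # Fold in the step, then the layer index, so each (step, layer) pair
    # gets its own key with no extra bookkeeping.
    key = jax.random.fold_in(jax.random.fold_in(base_key, step), layer)
    # True means "keep this unit"; each entry is True with probability keep_prob.
    return jax.random.bernoulli(key, p=keep_prob, shape=shape)

base_key = jax.random.PRNGKey(0)
masks = [dropout_mask(base_key, step, layer=0, shape=(4,)) for step in range(5)]

# Replaying step 3 later (e.g. while debugging) reproduces the same mask.
replayed = dropout_mask(base_key, 3, layer=0, shape=(4,))
print(jnp.array_equal(masks[3], replayed))  # True
```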

Worked mini-example

import jax, jax.numpy as jnp

base_key = jax.random.PRNGKey(0)

# Sub-key for step 7:
k7 = jax.random.fold_in(base_key, 7)
noise_at_step_7 = jax.random.normal(k7, (3,))

# The SAME key is recoverable later:
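jnp.array_equal(jax.random.fold_in(base_key, 7), jax.random.fold_in(base_key, 7))
# → True (deterministic). A bare == compares the two uint32 keys element-wise and returns a boolean array, not a single True.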

# split, by contrast, creates several keys in one call:
k1, k2 = jax.random.split(base_key, 2)   # one shot, two keys
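
fold_in returns one key per call, but it also composes with jax.vmap when you want a whole batch of per-step keys at once. The sketch below (again assuming legacy uint32 keys, so each key has shape (2,)) checks that the batched keys match a one-at-a-time loop. The split keys come from a separate derivation, so they are not expected to match the fold_in keys.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)
steps = jnp.arange(5)

# Batched fold_in: one key per step index, shape (5, 2) for legacy uint32 keys.
step_keys = jax.vmap(lambda s: jax.random.fold_in(base_key, s))(steps)

# The same keys, produced one fold_in call at a time.
loop_keys = jnp.stack([jax.random.fold_in(base_key, int(s)) for s in steps])
print(jnp.array_equal(step_keys, loop_keys))  # True

# split also returns a batch of keys, via a different derivation,
# so step_keys[i] is not expected to equal split_keys[i].
split_keys = jax.random.split(base_key, 5)
print(step_keys.shape, split_keys.shape)  # (5, 2) (5, 2)
```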

Common pitfalls

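  • fold_in(key, data) expects a scalar 32-bit integer as data. Cast a float-typed loop index with int(k) first.
  • Folding the same data into the same key always returns the same sub-key, since the output depends only on (key, data). To get different sub-keys, fold in different values (e.g. distinct step numbers or layer indices).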
  • fold_in produces ONE sub-key; for N independent sub-keys at once, use split.
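  • Inside jit, a Python for-loop is unrolled at trace time, so a long loop of fold_in calls inflates the compiled program. Use jax.lax.scan (or jax.lax.fori_loop) instead; the loop body can call fold_in(base_key, k) on the step index, or use the carry-key pattern (carry a key and split it every step).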
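
A minimal sketch of the jit-friendly alternatives, assuming n_steps and n are Python ints known at trace time (marked static). noise_fold_in_scan and noise_carry_key_scan are hypothetical names. The first calls fold_in on the scanned step index; the second carries a key through the scan and splits it each step. Both give one row of noise per step, but they draw different numbers because the keys are derived differently.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1, 2))
def noise_fold_in_scan(base_key, n_steps, n):
    """One row of N(0, 1) noise per step; step k uses fold_in(base_key, k)."""
    def body(carry, k):
        step_key = jax.random.fold_in(base_key, k)
        return carry, jax.random.normal(step_key, (n,))
    # No real carry is needed, so None is threaded through the scan.
    _, rows = jax.lax.scan(body, None, jnp.arange(n_steps))
    return rows

@partial(jax.jit, static_argnums=(1, 2))
def noise_carry_key_scan(base_key, n_steps, n):
    """Carry-key pattern: split the carried key once per step."""
    def body(key, _):
        key, sub = jax.random.split(key)
        return key, jax.random.normal(sub, (n,))
    _, rows = jax.lax.scan(body, base_key, None, length=n_steps)
    return rows

base_key = jax.random.PRNGKey(0)
rows_a = noise_fold_in_scan(base_key, 4, 3)
rows_b = noise_carry_key_scan(base_key, 4, 3)
print(rows_a.shape, rows_b.shape)  # (4, 3) (4, 3)
```

The fold_in version lets you recompute row k directly from (base_key, k); the carry-key version has to replay the earlier splits to reach step k.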

Problem

Implement fold_in_per_step(seed, n_steps, n):

  • Create a base key from seed using jax.random.PRNGKey(int(seed)).
  • Loop k from 0 to n_steps - 1.
  • At each step, derive step_key = jax.random.fold_in(base_key, k).
  • Draw n samples from N(0, 1) using step_key.
  • Stack all per-step samples into a 2-D array of shape (n_steps, n).

All three arguments arrive as floats; cast them to int inside the function.

Two illustrative examples (not from the test set):

  • fold_in_per_step(1, 3, 4) returns a (3, 4) array — three rows of four independent standard-normal values, one row per step.
  • fold_in_per_step(99, 5, 1) returns a (5, 1) array — five scalars, one per step, each drawn from N(0, 1) via its own fold_in-derived key.
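
One possible sketch of fold_in_per_step, following the steps listed above; it is not the site's reference solution. The eager Python loop is fine here because the function is not jit-compiled. Returning an empty (0, n) array when n_steps is 0 is an added assumption, since jnp.stack rejects an empty list.

```python
import jax
import jax.numpy as jnp

def fold_in_per_step(seed, n_steps, n):
    # Arguments arrive as floats; cast before using them as a seed or shape.
    seed, n_steps, n = int(seed), int(n_steps), int(n)
    base_key = jax.random.PRNGKey(seed)
    rows = []
    for k in range(n_steps):
        # Each step's key depends only on (seed, k).
        step_key = jax.random.fold_in(base_key, k)
        rows.append(jax.random.normal(step_key, (n,)))
    if not rows:
        # Assumption: zero steps yields an empty (0, n) array.
        return jnp.zeros((0, n))
    # Stack the per-step samples into shape (n_steps, n).
    return jnp.stack(rows)

print(fold_in_per_step(1.0, 3.0, 4.0).shape)   # (3, 4)
print(fold_in_per_step(99.0, 5.0, 1.0).shape)  # (5, 1)
```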

Hints

jax prng fold-in
