nnx.Rngs Container
Why this matters
JAX’s randomness is explicit and stateless: every random op needs a key,
and you split keys yourself with jax.random.split. This is precise but
verbose — every layer that needs randomness has to manage its own splits.
nnx introduces nnx.Rngs, a small stateful container that owns one or more
PRNG streams and hands out a fresh key every time you ask. Internally each
stream derives a new key on every call, so you never write jax.random.split
in nnx code. Streams are addressed by name: rngs.params() for parameter
init, rngs.dropout() for dropout masks, rngs.noise() for whatever you
invent.
Result: a module with one Linear, one Dropout, and a stochastic noise layer
needs to receive ONE rngs argument, not three split-and-threaded keys.
The rngs object handles the bookkeeping.
API: nnx.Rngs
Construct with a seed (int) or with named seeds:
rngs = nnx.Rngs(0) # default stream seeded with 0
rngs = nnx.Rngs(params=0, dropout=1) # multiple named streams
Draw a key by calling the stream as a method. Each call returns a new key:
rngs = nnx.Rngs(0)
k1 = rngs.params()
k2 = rngs.params()
# k1 != k2 — Rngs split internally on each call.
nnx.Rngs(0) creates a single stream named default. Asking for a stream that
was never declared (e.g., rngs.params() or rngs.dropout() on an Rngs built from
a single seed) falls back to the default stream's next key, which is handy for testing.
Contrast with raw JAX
The same idea, written without nnx:
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key) # manual split
a = jax.random.normal(k1, (4,))
b = jax.random.normal(k2, (4,))
With nnx:
rngs = nnx.Rngs(0)
a = jax.random.normal(rngs.params(), (4,))
b = jax.random.normal(rngs.params(), (4,)) # automatically split
Both are deterministic for a given seed (although the two versions derive their keys differently, so they will generally not produce the same numbers). The nnx version is just shorter and composes through deeply nested modules without having to thread keys through every layer.
Worked example
rngs = nnx.Rngs(42)
k1 = rngs.params()
k2 = rngs.params()
a = jax.random.normal(k1, (4,)) # one batch of 4 standard normals
b = jax.random.normal(k2, (4,)) # different 4 standard normals
print(a + b) # element-wise sum of the two
Re-running with the same seed gives bit-for-bit identical a + b.
Common pitfalls
- Using rngs (the bare object) as a key instead of rngs.params(). The Rngs container is not a key.
- Reusing a key. Don't call rngs.params() once and reuse the same key for two different random ops; make a fresh call each time.
- Constructing inside a hot loop. nnx.Rngs(0) restarts the stream from seed 0 each time it is built. Build it once outside the loop, or you'll get the same key over and over.
- Confusing nnx.Rngs(0) with jax.random.PRNGKey(0). They're different types: Rngs is the container, while jax.random.PRNGKey(0) returns a single key.
- Casting: seed, k1_offset, and k2_offset arrive as floats. Cast seed to int for the Rngs and use the offsets as floats.
Problem
Write rngs_two_keys_sum(seed, k1_offset, k2_offset):
- Build rngs = nnx.Rngs(int(seed)).
- Draw two keys with rngs.params(). They must come from two SEPARATE calls; do not split a single key manually.
- Generate a = jax.random.normal(key1, (4,)) + k1_offset and b = jax.random.normal(key2, (4,)) + k2_offset.
- Return a + b (a length-4 array).
The two keys are different (Rngs splits internally), so a and b are
independent draws — the result is not 2 * jax.random.normal(...) + sum_of_offsets.
Inputs:
- seed: int (passed as float).
- k1_offset, k2_offset: floats added to each draw.
Output: 1-D array of length 4.