
nnx.Rngs Container

Why this matters

JAX’s randomness is explicit and stateless: every random op needs a key, and you split keys yourself with jax.random.split. This is precise but verbose — every layer that needs randomness has to manage its own splits.

nnx introduces nnx.Rngs, a small stateful container that owns one or more PRNG streams and hands out a fresh key every time you ask. Each stream derives a new key on every call (in current Flax releases, by folding a per-stream call counter into the stream’s base key rather than by literally calling jax.random.split), so you rarely need to write jax.random.split yourself in nnx code. Streams are addressed by name: rngs.params() for parameter init, rngs.dropout() for dropout masks, rngs.noise() for whatever you invent.

Result: a module with one Linear, one Dropout, and a stochastic noise layer needs to receive ONE rngs argument, not three split-and-threaded keys. The rngs object handles the bookkeeping.

API: nnx.Rngs

Construct with a seed (int) or with named seeds:

from flax import nnx
rngs = nnx.Rngs(0)                    # one stream, named default, seeded with 0
rngs = nnx.Rngs(params=0, dropout=1)  # multiple named streams

Draw a key by calling the stream as a method. Each call returns a new key:

rngs = nnx.Rngs(0)
k1 = rngs.params()
k2 = rngs.params()
# k1 != k2: each call advances the stream and returns a new key.

Constructing Rngs with a single positional seed creates one stream, named default. Asking for any other name (e.g., rngs.params() or rngs.dropout()) on such an Rngs falls back to the default stream’s next key, which is handy for testing. Because all of these calls draw from the same stream, each one still returns a different key.

Contrast with raw JAX

The same idea, written without nnx:

import jax

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)        # manual split
a = jax.random.normal(k1, (4,))
b = jax.random.normal(k2, (4,))

With nnx:

rngs = nnx.Rngs(0)
a = jax.random.normal(rngs.params(), (4,))
b = jax.random.normal(rngs.params(), (4,))    # fresh key, no manual split

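Each version produces a deterministic result given the same seed, although the two need not produce the same numbers as each other because the keys are derived differently. The nnx version is shorter and composes through deeply nested modules without threading keys through every layer.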

Worked example

import jax
from flax import nnx

rngs = nnx.Rngs(42)
k1 = rngs.params()
k2 = rngs.params()
a = jax.random.normal(k1, (4,))   # one batch of 4 standard normals
b = jax.random.normal(k2, (4,))   # different 4 standard normals
print(a + b)                      # element-wise sum of the two

Re-running with the same seed gives bit-for-bit identical a + b.

Common pitfalls

  • Calling rngs (the bare object), not rngs.params(). The bare object is not a key.
  • Reusing a key. Don’t call rngs.params() once and reuse the same key for two different random ops — use a fresh call each time.
  • Constructing inside a hot loop. nnx.Rngs(0) resets the stream to seed 0 each time. Build it once outside the loop, or you’ll get the same key over and over.
  • Confusing nnx.Rngs(0) with jax.random.PRNGKey(0). They produce different types: nnx.Rngs(0) builds the container, while jax.random.PRNGKey(0) returns a single key.
  • Casting: in the problem below, seed, k1_offset, and k2_offset arrive as floats. Cast seed to int before passing it to nnx.Rngs, and use the offsets as floats.

Problem

Write rngs_two_keys_sum(seed, k1_offset, k2_offset):

  1. Build rngs = nnx.Rngs(int(seed)).
  2. Draw two keys with rngs.params(). They must come from two SEPARATE calls — do not split a single key manually.
  3. Generate a = jax.random.normal(key1, (4,)) + k1_offset and b = jax.random.normal(key2, (4,)) + k2_offset.
  4. Return a + b (a length-4 array).

The two keys are different (Rngs returns a new key on each call), so a and b are independent draws. The result is therefore not 2 * jax.random.normal(...) + sum_of_offsets.

Inputs:

  • seed: int (passed as float).
  • k1_offset, k2_offset: floats added to each draw.

Output: 1-D array of length 4.

