Multiple PRNG Streams (params, dropout)

Why this matters

JAX is intentionally explicit about randomness: every random op takes a PRNG key as an argument, and every Flax init call (plus any apply call that samples) must be told where its keys come from. Many Modules need only a single key for parameter init. But as soon as you add nn.Dropout, layer-wise stochastic depth, masking, etc., you need separate streams: one for params, one (or more) for runtime randomness.

Flax’s “RNG collections” mechanism is the answer. Instead of passing a single key to init/apply, you pass a dict mapping stream names to keys.

The two canonical streams

  • "params" — used during init to draw initial parameter values.
  • "dropout" — used during apply whenever a layer needs runtime randomness (dropout mask, noise injection, etc.).

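init always needs at least the "params" key. If init runs the module with Dropout active (deterministic=False), also pass a "dropout" key so the trace through nn.Dropout has a stream to draw from. Older Flax releases raise an error when that stream is missing; newer releases may instead fall back to the "params" stream, so passing it explicitly keeps the behavior unambiguous.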

apply needs whatever streams the forward path consumes — typically only "dropout" if the module has Dropout in train mode; nothing if deterministic=True.

init signature with multiple streams

init_key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)

params = model.init(
    {"params": init_key, "dropout": drop_key},  # NOTE: dict, not single key
    x,
    deterministic=False,                         # any forward kwargs
)

The dict form replaces the single-key form. The single-key form is shorthand for {"params": key} — it works only when the model has no other RNG needs.

apply signature with rngs=

out = model.apply(
    params,
    x,
    deterministic=False,
    rngs={"dropout": jax.random.PRNGKey(7)},
)

Note the kwarg name is rngs= (plural). You only pass streams the forward path will consume. If deterministic=True is wired correctly, no dropout randomness is needed and you can omit rngs=.

Worked example

import jax
import jax.numpy as jnp
import flax.linen as nn

class DropoutModel(nn.Module):
    drop_rate: float

    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Dense(features=x.shape[-1])(x)
        x = nn.Dropout(rate=self.drop_rate, deterministic=deterministic)(x)
        return x

model = DropoutModel(drop_rate=0.5)
x = jnp.ones((4,))  # example input; any 1-D float array works

# Init: needs both streams
params = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    x,
    deterministic=False,
)

# Apply (training, dropout active):
y_train = model.apply(params, x, deterministic=False,
                      rngs={"dropout": jax.random.PRNGKey(7)})

# Apply (eval, dropout off):
y_eval = model.apply(params, x, deterministic=True)
# No rngs= needed since dropout doesn't sample.

Key splitting in real training loops

In real code, you split a top-level key per step so dropout uses fresh randomness:

rng = jax.random.PRNGKey(0)
for step in range(num_steps):
    rng, drop_rng = jax.random.split(rng)
    out = model.apply(params, batch, deterministic=False,
                      rngs={"dropout": drop_rng})

Never reuse a key for two sampling calls: identical keys produce identical samples, so every step would apply the same dropout mask.

Common pitfalls

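  • Single key with Dropout in init: model.init(key, x) supplies only the "params" stream. With Dropout active, older Flax releases raise an error about the missing "dropout" RNG, while newer releases may silently draw from "params" instead. Use the dict form so the streams stay separate.
  • Forgetting rngs= in apply: if deterministic=False and you pass no rngs at all, Dropout has no key to draw from and Flax raises an error (flax.errors.InvalidRngError).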
  • Reusing the same dropout key: every call uses the same mask — the stochasticity is fake. Always split.
  • Wrong arg name: it’s rngs= (plural), not rng=.

Problem

Build DropoutModel(drop_rate):

  1. Dense(features=x.shape[-1])
  2. Dropout(rate=drop_rate, deterministic=deterministic)

The function:

  • Takes seed, x, dropout_seed, drop_rate, deterministic. x is an array; the other four arguments arrive as floats.
  • Casts deterministic to bool via bool(deterministic >= 0.5).
  • Builds init_key = jax.random.PRNGKey(int(seed)) and drop_key = jax.random.PRNGKey(int(dropout_seed)).
  • Inits with {"params": init_key, "dropout": drop_key}, using deterministic=False so the trace through Dropout sees the dropout stream.
  • Applies with the given deterministic flag and rngs={"dropout": drop_key}.

Inputs:

  • seed: float (cast to int) — params init seed.
  • x: 1-D JAX array.
  • dropout_seed: float (cast to int) — dropout key seed.
  • drop_rate: float in [0, 1].
  • deterministic: float (0.0 or 1.0).

Output: 1-D array, same length as x.

Hints

flax rng dropout
