Multiple PRNG Streams (params, dropout)
Why this matters
JAX is intentionally explicit about randomness: every random op needs its own
PRNGKey, and every Flax init and apply call must be told where the keys
come from. Most Modules use a single key for parameter init. But as soon as
you add nn.Dropout, layer-wise stochastic depth, masking, etc., you need
multiple separate streams — one for params, one (or more) for runtime
randomness.
Flax’s “RNG collections” mechanism is the answer. Instead of passing a single
key to init/apply, you pass a dict mapping stream names to keys.
The two canonical streams
- "params": used during init to draw initial parameter values.
- "dropout": used during apply whenever a layer needs runtime randomness (dropout mask, noise injection, etc.).
init always needs at least the "params" key. If your module has Dropout and the init trace runs it with
deterministic=False, init should ALSO get a "dropout" key so nn.Dropout has a stream to draw its mask from.
apply needs whatever streams the forward path consumes: typically only
"dropout" if the module has Dropout in train mode, and nothing if deterministic=True.
init signature with multiple streams
```python
init_key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)
params = model.init(
    {"params": init_key, "dropout": drop_key},  # NOTE: dict, not single key
    x,
    deterministic=False,  # any forward kwargs
)
```
The dict form replaces the single-key form. The single-key form is shorthand
for {"params": key}, so it only covers models whose forward pass needs no other RNG stream during init.
apply signature with rngs=
```python
out = model.apply(
    params,
    x,
    deterministic=False,
    rngs={"dropout": jax.random.PRNGKey(7)},
)
```
Note the kwarg name is rngs= (plural). You only pass streams the forward
path will consume. If deterministic=True is wired correctly, no dropout
randomness is needed and you can omit rngs=.
Worked example
```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class DropoutModel(nn.Module):
    drop_rate: float

    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Dense(features=x.shape[-1])(x)
        x = nn.Dropout(rate=self.drop_rate, deterministic=deterministic)(x)
        return x


model = DropoutModel(drop_rate=0.5)
x = jnp.ones((4,))  # example input

# Init: needs both streams. `params` is the full variables dict {"params": {...}}.
params = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
    x,
    deterministic=False,
)

# Apply (training, dropout active):
y_train = model.apply(params, x, deterministic=False,
                      rngs={"dropout": jax.random.PRNGKey(7)})

# Apply (eval, dropout off):
y_eval = model.apply(params, x, deterministic=True)
# No rngs= needed since dropout doesn't sample.
```
Key splitting in real training loops
In real code, you split a top-level key per step so dropout uses fresh randomness:
```python
rng = jax.random.PRNGKey(0)
for step in range(num_steps):
    # Split off a fresh subkey for this step and carry the other half forward.
    rng, drop_rng = jax.random.split(rng)
    out = model.apply(params, batch, deterministic=False,
                      rngs={"dropout": drop_rng})
```
Never reuse the same key: a given key always produces the same draws (here, the same dropout mask on every step), which silently defeats the randomness.
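The difference between reusing and splitting is visible with jax.random alone, with no Flax involved. This sketch uses jax.random.bernoulli as a stand-in for a dropout mask. It also shows jax.random.fold_in, a different technique that derives each step's key from a base key and the step index, which some codebases use instead of carrying a split key through the loop.

```python
import jax
import jax.numpy as jnp

base = jax.random.PRNGKey(0)
shape = (64,)

# Reusing one key: the "random" mask is identical every time.
m1 = jax.random.bernoulli(base, p=0.5, shape=shape)
m2 = jax.random.bernoulli(base, p=0.5, shape=shape)
print(bool(jnp.array_equal(m1, m2)))  # True

# Split pattern from the loop above: a fresh subkey per step, the carried key moves on.
rng = base
masks = []
for step in range(3):
    rng, drop_rng = jax.random.split(rng)
    masks.append(jax.random.bernoulli(drop_rng, p=0.5, shape=shape))

# Alternative technique: derive each step's key with fold_in(base, step).
fold_masks = [
    jax.random.bernoulli(jax.random.fold_in(base, step), p=0.5, shape=shape)
    for step in range(3)
]
```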
Common pitfalls
- Single key with Dropout in init: model.init(key, x) (single key) on a model whose Dropout runs with deterministic=False supplies no "dropout" stream. Older Flax versions raise a missing-RNG error; newer ones fall back to the "params" key, which hides the mistake. Use the dict form.
- Forgetting rngs= in apply: if deterministic=False and you don't pass rngs={"dropout": ...}, you get a missing-RNG error.
- Reusing the same dropout key: every call uses the same mask, so the stochasticity is fake. Always split.
- Wrong arg name: it's rngs= (plural), not rng=.
Problem
Build DropoutModel(drop_rate):
- Dense(features=x.shape[-1])
- Dropout(rate=drop_rate, deterministic=deterministic)

The function:
- Takes seed, x, dropout_seed, drop_rate, deterministic (all floats except x, which is an array).
- Casts deterministic to bool via bool(deterministic >= 0.5).
- Inits with both "params" and "dropout" streams, using seed and dropout_seed respectively. Init with deterministic=False so the trace through Dropout sees the "dropout" stream.
- Applies with the given deterministic flag and rngs={"dropout": drop_key}, where drop_key is the key built from dropout_seed.
Inputs:
- seed: float (cast to int), the params init seed.
- x: 1-D JAX array.
- dropout_seed: float (cast to int), the dropout key seed.
- drop_rate: float in [0, 1].
- deterministic: float (0.0 or 1.0).

Output: 1-D array, same length as x.