Implement Dropout with RNG Threading
Why this matters
Dropout looks trivially simple in pseudocode: “drop each element with
probability p”. But Flax’s strict purity contract makes the
implementation interesting:
- You can’t access a global RNG: every random draw needs an explicit PRNG key.
- Inside __call__, where do you get the key from? Flax provides self.make_rng("dropout"), which threads a key from the rngs={"dropout": key} dict you pass to apply.
- The forward pass branches on deterministic: at eval time it’s a no-op; at train time it samples a mask AND scales by 1/keep_prob so expected activations don’t change.
Reimplementing Dropout makes you understand all three. Once it clicks, every Flax layer with runtime randomness (StochasticDepth, NoiseInjection, sampling decoders) follows the same pattern.
Inverted dropout: scale by 1/keep_prob
The standard “inverted dropout” used everywhere in modern DL:
keep_prob = 1 - rate
mask = bernoulli(p=keep_prob, shape=x.shape) # 1 = kept, 0 = dropped
y = where(mask, x / keep_prob, 0)
Why scale up? At train time the EXPECTED value of y matches x:
E[y] = keep_prob * (x / keep_prob) + (1 - keep_prob) * 0 = x
So at eval time you can use y = x (the identity) without any further
rescaling — the network was trained to expect inputs at this scale.
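To check the expectation argument numerically, here is a minimal sketch in plain JAX (no Flax) of the inverted-dropout pseudocode above, averaged over many independently drawn masks. The helper name inverted_dropout and the sample count are illustrative choices.

```python
import jax
import jax.numpy as jnp


def inverted_dropout(key, x, rate):
    """Inverted dropout on a plain array, mirroring the pseudocode."""
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)  # True = kept
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.array([1.0, -2.0, 3.0])
rate = 0.3
# One independent key per sample, then vmap the dropout over all keys.
keys = jax.random.split(jax.random.PRNGKey(0), 20_000)
samples = jax.vmap(lambda k: inverted_dropout(k, x, rate))(keys)
print(samples.mean(axis=0))  # close to x, as E[y] = x predicts
```

Every individual sample is either 0 or x / keep_prob; only the average over masks recovers x.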
self.make_rng("dropout")
Inside an @nn.compact method, you obtain a fresh PRNG key for a named stream:
rng = self.make_rng("dropout")
Flax derives this key from the "dropout" key passed to init/apply by
folding in the calling module’s path and a per-stream call counter, so
different layers, and repeated calls within one layer, get different keys
automatically. You don’t need to split or thread keys manually inside the model.
Calling make_rng("dropout") from __call__:
- During init: make_rng is only reached when deterministic=False, and then you should pass a "dropout" key in the init RNG dict alongside "params". With deterministic=True the branch is skipped, make_rng is never called, and no dropout key is needed.
- During apply: in train mode (deterministic=False), it works only if you pass rngs={"dropout": key}.
Worked implementation
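import jax
import jax.numpy as jnp
import flax.linen as nn

class MyDropout(nn.Module):
    rate: float  # probability of dropping each element

    @nn.compact
    def __call__(self, x, deterministic: bool):
        # Eval mode, or nothing to drop: identity, and no RNG is requested.
        if deterministic or self.rate == 0.0:
            return x
        keep_prob = 1.0 - self.rate
        # Fresh key from the "dropout" stream passed to init/apply.
        rng = self.make_rng("dropout")
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape)  # True = kept
        # Inverted dropout: scale kept elements so that E[y] = x.
        return jnp.where(mask, x / keep_prob, 0.0)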
Note the early return for deterministic=True and rate == 0.0: both skip
the make_rng call, so no "dropout" key is needed in those cases.
Common pitfalls
- x * mask, then dividing later: works, but the bookkeeping is messier. where(mask, x / keep_prob, 0) is the canonical form.
- Calling make_rng("dropout") outside __call__: it only works inside __call__ (or a method called from it) during init/apply, because that’s when Flax has the rngs dict in scope.
- Forgetting rngs={"dropout": key} in apply: make_rng raises an error because no PRNG key was supplied for the "dropout" stream.
- Reusing the same key across calls: the dropout mask repeats, which defeats the randomization. In real training loops, split off a fresh key per step.
Problem
Implement MyDropout(rate):
- If deterministic=True OR rate == 0.0, return x.
- Else:
  - keep_prob = 1 - rate.
  - rng = self.make_rng("dropout").
  - mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape).
  - Return jnp.where(mask, x / keep_prob, 0.0).
- The function does the full init + apply:
  - Init with {"params": PRNGKey(seed), "dropout": PRNGKey(seed+1)} and deterministic=False.
  - Apply with the test’s deterministic flag and rngs={"dropout": PRNGKey(seed+1)}.
Inputs:
- seed: float (cast to int).
- x: 1-D JAX array.
- rate: float.
- deterministic: float (0.0 or 1.0).
Output: 1-D array (same shape as x).