NNX Implement Dropout
Why this matters
Dropout is the canonical example of a layer that needs three things at
once: a hyperparameter (the drop rate), a fresh PRNG key per call (so
every forward gets a different mask), and a train/eval flag (eval is
a no-op). In Linen this required passing rngs={"dropout": key} at
every apply call site — easy to forget, and a frequent source of
‘eval looks fine, train doesn’t’ bugs.
nnx threads keys through an nnx.Rngs container that the module owns
as an attribute. Each call to self.rngs.dropout() returns a fresh
key, because the container advances its own internal state; this is the
same mechanism as parameter init, just under a different stream name.
Nothing has to be threaded through the call site.
Inverted dropout
Standard practice is inverted dropout: scale the kept activations UP at training time so that the expected output equals the input. Then at eval, the layer is a no-op:
keep_prob = 1.0 - rate
mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
out = jnp.where(mask, x / keep_prob, 0.0)
Without the /keep_prob rescale, the network would see different
activation magnitudes between train and eval; inverted dropout makes
them match in expectation, so the eval branch can simply return x.
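The claim about matching magnitudes can be checked numerically. The sketch below is only an illustration: the rate, the input of all ones, the sample size, and the seed are arbitrary choices, not values from the problem. It compares the mean activation with and without the /keep_prob rescale.

```python
import jax
import jax.numpy as jnp

# Illustrative values (assumptions, not part of the problem statement).
rate = 0.25
keep_prob = 1.0 - rate
x = jnp.ones((100_000,))

key = jax.random.PRNGKey(0)
mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)

# Inverted dropout: kept activations are scaled up by 1 / keep_prob.
inverted = jnp.where(mask, x / keep_prob, 0.0)
# Plain dropout: kept activations are left unscaled.
plain = jnp.where(mask, x, 0.0)

inverted_mean = float(inverted.mean())  # close to the input mean (1.0)
plain_mean = float(plain.mean())        # close to keep_prob (0.75)
```

With the rescale, the training-time mean stays near the input mean, which is why the eval branch can return x unchanged.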
API: rngs as a member
class MyDropout(nnx.Module):
    def __init__(self, rate, rngs):
        self.rate = rate
        self.rngs = rngs  # save the container

    def __call__(self, x, deterministic):
        if deterministic or self.rate <= 0.0:
            return x
        keep_prob = 1.0 - self.rate
        key = self.rngs.dropout()  # fresh key per call
        mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
        return jnp.where(mask, x / keep_prob, 0.0)
rngs.dropout() is a method on the rngs container that returns a
new key every time it’s called. Internally the stream keeps a base key
and a call counter and derives each new key from the two, so you don’t
have to manage that yourself.
The dropout name is just a stream name. nnx.Rngs(0) seeds a single
default stream; asking for a name that was not seeded explicitly, such
as .dropout(), falls back to that default stream. Later you can seed it
separately from params with nnx.Rngs(params=0, dropout=42).
Linen contrast
# Linen, for contrast.
class MyDropout(nn.Module):
    rate: float

    @nn.compact
    def __call__(self, x, deterministic):
        if deterministic or self.rate <= 0:
            return x
        key = self.make_rng("dropout")
        keep_prob = 1.0 - self.rate
        mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
        return jnp.where(mask, x / keep_prob, 0)

# Apply site: must pass rngs every time.
y = model.apply(params, x, deterministic=False, rngs={"dropout": key})
The painful part is the apply site: forget rngs={"dropout": key} and
you get a runtime error. nnx folds the key container into the module
so it’s set up once (at construction) and the call site is just
model(x, deterministic=False).
Common pitfalls
- Forgetting to scale by 1/keep_prob. Without inverted dropout, training-time activations are smaller than eval-time, and the network has to relearn this gap.
- Returning x vs x.copy() in eval. JAX arrays are immutable; returning x directly is fine, and there’s no aliasing concern.
- Sampling with the WRONG p. bernoulli(p) returns True with probability p. We want True ≈ “keep”, so p = keep_prob, NOT rate.
- deterministic as a float. The harness passes 0.0 or 1.0; cast with bool(deterministic >= 0.5) (or int(deterministic)).
- Forgetting the rate <= 0 short-circuit. With rate = 0, keep_prob = 1, so the mask is all-True and the math still works. But you can save a key and a sample by short-circuiting.
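Two of these pitfalls are easy to check in isolation. The sketch below uses an arbitrary rate, sample size, and seed. The helper name to_bool is hypothetical, introduced only for this illustration.

```python
import jax

# Illustrative values (assumptions, not part of the problem statement).
rate = 0.1
keep_prob = 1.0 - rate
key = jax.random.PRNGKey(1)
n = 100_000

# Correct: True means "keep", so sample with p = keep_prob.
right = jax.random.bernoulli(key, p=keep_prob, shape=(n,))
# Wrong: sampling with p = rate keeps only about `rate` of the units.
wrong = jax.random.bernoulli(key, p=rate, shape=(n,))

kept_right = float(right.mean())  # close to keep_prob
kept_wrong = float(wrong.mean())  # close to rate


def to_bool(flag):
    """Hypothetical helper: turn a 0.0 / 1.0 float flag into a Python bool."""
    return bool(flag >= 0.5)


train_flag = to_bool(0.0)  # False: dropout is active
eval_flag = to_bool(1.0)   # True: the layer is a no-op
```

Casting to a Python bool before the call keeps the `if deterministic or ...` branch an ordinary Python conditional.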
Problem
Write dropout_forward(seed, x, rate, deterministic):
- Define MyDropout(nnx.Module) storing self.rate = rate and self.rngs = rngs (no Param/Variable wrappers; these are not trainable state).
- __call__(self, x, deterministic):
  - If deterministic is True or rate <= 0, return x.
  - Otherwise: key = self.rngs.dropout(), mask = jax.random.bernoulli(key, p=1-rate, shape=x.shape), return jnp.where(mask, x / (1 - rate), 0.0).
- Cast deterministic from float to bool.
- Build with nnx.Rngs(int(seed)), instantiate MyDropout(rate, rngs), return model(x, deterministic=det).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- rate: float in [0, 1).
- deterministic: float (0.0 or 1.0).
Output: same shape as x.