medium primitives

NNX: Implement Dropout

Why this matters

Dropout is the canonical example of a layer that needs three things at once: a hyperparameter (the drop rate), a fresh PRNG key per call (so every forward pass gets a different mask), and a train/eval flag (eval is a no-op). In Linen this meant passing rngs={"dropout": key} at every apply call site. That is easy to forget, and a frequent source of ‘eval looks fine, train doesn’t’ bugs, because the eval path never asks for a key.

nnx threads keys through an nnx.Rngs container that the module owns as an attribute. Each call to self.rngs.dropout() returns a fresh key, drawn from a dropout stream the same way parameter init draws from a params stream; only the stream name differs. Nothing has to be threaded through the call site.

Inverted dropout

Standard practice is inverted dropout: scale the kept activations UP at training time so that the expected output equals the input. That is what lets the layer be a no-op at eval. The training branch is:

keep_prob = 1.0 - rate
mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
out = jnp.where(mask, x / keep_prob, 0.0)

Without the /keep_prob rescale, the network would see different activation magnitudes between train and eval; inverted dropout makes them match in expectation, so the eval branch can simply return x.
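Per element, the expectation works out to keep_prob · (x / keep_prob) + rate · 0 = x. A quick empirical check of that claim in plain JAX, using a constant input so the mean is easy to read (the rate of 0.3 and the sample size are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp

rate = 0.3
keep_prob = 1.0 - rate

x = jnp.ones((100_000,))        # constant input: E[out] should be 1.0
key = jax.random.key(0)

mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
out = jnp.where(mask, x / keep_prob, 0.0)

print(float(out.mean()))         # close to 1.0
print(float((out == 0).mean()))  # fraction dropped, close to rate

# Without the rescale, the mean shrinks to roughly keep_prob.
unscaled = jnp.where(mask, x, 0.0)
print(float(unscaled.mean()))    # close to 0.7
```

Every surviving element equals x / keep_prob, so the eval branch returning x unchanged matches the training-time average.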

API: rngs as a member

import jax
import jax.numpy as jnp
from flax import nnx

class MyDropout(nnx.Module):
    def __init__(self, rate, rngs):
        self.rate = rate
        self.rngs = rngs                    # save the container

    def __call__(self, x, deterministic):
        if deterministic or self.rate <= 0.0:
            return x
        keep_prob = 1.0 - self.rate
        key = self.rngs.dropout()           # fresh key per call
        mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
        return jnp.where(mask, x / keep_prob, 0.0)

rngs.dropout is a stream on the rngs container; calling it returns a new key every time. Internally the stream keeps a counter and derives each key from its base key and the current count (via jax.random.fold_in), so you never manage splitting yourself.
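To make the counter idea concrete, here is a hypothetical CountingStream in plain JAX. It is a toy stand-in written for illustration, not Flax's actual stream implementation: it holds a base key plus a counter and folds the counter in on each call.

```python
import jax


class CountingStream:
    """Hypothetical toy stream: a base key plus a call counter."""

    def __init__(self, seed):
        self.base_key = jax.random.key(seed)
        self.count = 0

    def __call__(self):
        # Derive a fresh key from (base_key, count), then advance the counter.
        key = jax.random.fold_in(self.base_key, self.count)
        self.count += 1
        return key


stream = CountingStream(0)
draws = [float(jax.random.uniform(stream())) for _ in range(3)]
print(draws)         # three different values
print(stream.count)  # 3
```

The caller only ever sees "call it, get a new key"; the bookkeeping that would otherwise be a chain of jax.random.split calls lives inside the object.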

The name dropout is just a stream name. nnx.Rngs(0) registers a single default stream, and any stream you ask for that was not registered, such as .dropout() or .params(), falls back to it. To seed dropout independently of params, register both: nnx.Rngs(params=0, dropout=42).

Linen contrast

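# Linen, for contrast.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyDropout(nn.Module):
    rate: float

    @nn.compact
    def __call__(self, x, deterministic):
        if deterministic or self.rate <= 0:
            return x
        key = self.make_rng("dropout")      # drawn from the rngs dict given to apply
        keep_prob = 1.0 - self.rate
        mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
        return jnp.where(mask, x / keep_prob, 0.0)

model = MyDropout(rate=0.5)
x = jnp.ones((8,))
key = jax.random.PRNGKey(0)
variables = model.init(key, x, deterministic=True)  # no params, so this is empty

# Apply site: must pass rngs every time.
y = model.apply(variables, x, deterministic=False, rngs={"dropout": key})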

The painful part is the apply site: forget rngs={"dropout": key} and make_rng raises an error at apply time. nnx folds the key container into the module, so it is set up once (at construction) and the call site is just model(x, deterministic=False).

Common pitfalls

  • Forgetting to scale by 1/keep_prob. Without inverted dropout, training-time activations are smaller than eval-time ones by a factor of keep_prob in expectation. The weights are fit to the shrunken activations, so eval outputs come out too large unless you rescale at eval instead.
  • Returning x vs x.copy() in eval. JAX arrays are immutable, so returning x directly is fine: there is no aliasing concern.
  • Sampling with the WRONG p. bernoulli(p) returns True with probability p. We want True ≈ “keep”, so p = keep_prob, NOT rate.
  • deterministic as a float. The harness passes 0.0 or 1.0; cast with bool(deterministic >= 0.5) (or int(deterministic)).
  • Skipping the rate <= 0 short-circuit. This is harmless for correctness: with rate = 0, keep_prob = 1, so the mask is all True and the math still returns x. The short-circuit just saves a key draw and a sample.
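The sampling, casting, and short-circuit pitfalls above can be checked directly in plain JAX. This sketch uses an arbitrary rate of 0.2 and a large shape so the empirical fractions sit close to their expected values:

```python
import jax
import jax.numpy as jnp

rate = 0.2
keep_prob = 1.0 - rate
key = jax.random.key(1)
shape = (50_000,)

# Right: True means "keep", so sample with p = keep_prob.
keep = jax.random.bernoulli(key, p=keep_prob, shape=shape)
# Wrong: sampling with p = rate keeps only about 20% of units.
wrong = jax.random.bernoulli(key, p=rate, shape=shape)
print(float(keep.mean()), float(wrong.mean()))  # about 0.8 vs about 0.2

# The harness passes deterministic as 0.0 or 1.0; threshold it into a bool.
flags = {flag: bool(flag >= 0.5) for flag in (0.0, 1.0)}
print(flags)

# With rate = 0 (p = 1.0) every element is kept, so the math alone returns x.
all_keep = jax.random.bernoulli(key, p=1.0, shape=shape)
print(bool(all_keep.all()))
```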

Problem

Write dropout_forward(seed, x, rate, deterministic):

  1. Define MyDropout(nnx.Module) storing self.rate = rate and self.rngs = rngs (no Param/Variable wrappers — these are not trainable state).
  2. __call__(self, x, deterministic):
    • If deterministic is True or rate <= 0, return x.
    • Otherwise: key = self.rngs.dropout(), mask = jax.random.bernoulli(key, p=1-rate, shape=x.shape), return jnp.where(mask, x / (1 - rate), 0.0).
  3. Cast deterministic from float to bool.
  4. Build with nnx.Rngs(int(seed)), instantiate MyDropout(rate, rngs), return model(x, deterministic=det).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • rate: float in [0, 1).
  • deterministic: float (0.0 or 1.0).

Output: same shape as x.

Hints

flax nnx dropout rngs
