medium primitives

Implement Dropout with RNG Threading

Why this matters

Dropout looks trivial in pseudocode: “drop each element with probability p”. But Flax’s strict purity contract makes the implementation interesting:

  1. You can’t access a global RNG — every random draw needs an explicit PRNG key.
  2. Inside __call__, where do you get the key from? Flax provides self.make_rng("dropout"), which derives a key from the rngs={"dropout": key} dict you pass to apply.
  3. The forward pass branches on deterministic: at eval time it’s a no-op; at train time it samples a mask AND scales by 1/keep_prob so expected activations don’t change.

Reimplementing Dropout makes you understand all three. Once it clicks, every Flax layer with runtime randomness (StochasticDepth, NoiseInjection, sampling decoders) follows the same pattern.

Inverted dropout: scale by 1/keep_prob

The standard “inverted dropout” formulation, used by most modern deep learning frameworks:

keep_prob = 1 - rate
mask = bernoulli(p=keep_prob, shape=x.shape)   # 1 = kept, 0 = dropped
y = where(mask, x / keep_prob, 0)

Why scale up? At train time the EXPECTED value of y matches x:

E[y] = keep_prob * (x / keep_prob) + (1 - keep_prob) * 0 = x

So at eval time you can use y = x (the identity) without any further rescaling — the network was trained to expect inputs at this scale.
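The expectation argument above can be checked numerically. The following is a minimal sketch in plain JAX; the rate of 0.3, the constant input of 2.0, and the sample size are arbitrary choices for illustration, not values from the exercise.

```python
import jax
import jax.numpy as jnp

rate = 0.3
keep_prob = 1.0 - rate

# A constant input makes the expectation easy to read off: E[y] should be 2.0.
x = jnp.full((100_000,), 2.0)

key = jax.random.PRNGKey(0)
mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)  # True = kept
y = jnp.where(mask, x / keep_prob, 0.0)

kept_fraction = float(mask.mean())  # empirical keep probability
mean_y = float(y.mean())            # empirical E[y]
print(kept_fraction)  # close to keep_prob
print(mean_y)         # close to 2.0, the mean of x
```

Each surviving element is larger than its input (2.0 / 0.7 here), but averaged over the mask the activations keep their original scale.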

self.make_rng("dropout")

Inside a module method such as an @nn.compact __call__, you obtain a fresh PRNG key for a named stream:

rng = self.make_rng("dropout")

Flax derives this from the "dropout" key passed to init/apply and a path-dependent fold-in, so different layers get different keys automatically. You don’t need to split or thread keys manually.

Calling make_rng("dropout") from __call__:

  • During init: make_rng is reached only when deterministic=False, and then it works only if the init RNGs dict contains a "dropout" key. (If deterministic=True, the branch isn’t executed and make_rng isn’t called.)
  • During apply: works only if you passed rngs={"dropout": key}.

Worked implementation

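import jax
import jax.numpy as jnp
import flax.linen as nn

class MyDropout(nn.Module):
    rate: float  # probability of dropping each element

    @nn.compact
    def __call__(self, x, deterministic: bool):
        # Eval mode or rate 0: identity, and no RNG is consumed.
        if deterministic or self.rate == 0.0:
            return x
        keep_prob = 1.0 - self.rate
        rng = self.make_rng("dropout")  # key from the "dropout" stream
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape)  # True = kept
        # Inverted dropout: rescale kept elements so E[y] = x.
        return jnp.where(mask, x / keep_prob, 0.0)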

Note the early-return for deterministic=True and rate==0.0 — both avoid the unnecessary make_rng call.

Common pitfalls

  • x * mask then divide later: works but bookkeeping is messier. where(mask, x / keep_prob, 0) is the canonical form.
  • Calling make_rng("dropout") outside __call__: it only works inside __call__ (or a method called from it) while init or apply is running, because that’s when Flax has the rngs dict in scope.
  • Forgetting rngs={"dropout": key} in apply: make_rng raises an error saying the module needs a PRNG for "dropout" (the exact wording depends on the Flax version).
  • Same key reused across calls: dropout pattern repeats — defeats randomization. In real loops, split a fresh key per step.
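The last pitfall is easy to see in plain JAX. This is a minimal sketch, assuming a toy three-step loop and a 64-element mask; in a real Flax training loop the per-step key would be passed as rngs={"dropout": step_key}.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(0)
shape = (64,)
keep_prob = 0.5

# Reusing one key: every "step" draws the identical mask.
reused = [jax.random.bernoulli(base_key, p=keep_prob, shape=shape) for _ in range(3)]

# Splitting: carry one key forward and hand a fresh subkey to each step.
key = base_key
fresh = []
for _ in range(3):
    key, step_key = jax.random.split(key)
    fresh.append(jax.random.bernoulli(step_key, p=keep_prob, shape=shape))

same_when_reused = all(bool(jnp.array_equal(reused[0], m)) for m in reused[1:])
same_when_split = any(bool(jnp.array_equal(fresh[0], m)) for m in fresh[1:])
print(same_when_reused)  # True
print(same_when_split)   # False
```

Carrying `key` forward and consuming only `step_key` is one common convention; folding the step index into a base key with jax.random.fold_in is another option.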

Problem

Implement MyDropout(rate):

  1. If deterministic=True OR rate == 0.0, return x.
  2. Else:
    • keep_prob = 1 - rate.
    • rng = self.make_rng("dropout").
    • mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape).
    • Return jnp.where(mask, x / keep_prob, 0.0).

Your function performs the full init and apply:

  1. Init with {"params": PRNGKey(seed), "dropout": PRNGKey(seed+1)}, deterministic=False.
  2. Apply with the test’s deterministic flag and rngs={"dropout": PRNGKey(seed+1)}.

Inputs:

  • seed: float (cast to int).
  • x: 1-D JAX array.
  • rate: float.
  • deterministic: float (0.0 or 1.0).

Output: 1-D array (same shape as x).

Hints

flax dropout make-rng
