
NNX Train/Eval Flag

Why this matters

Many layers behave differently between training and evaluation:

  • Dropout zeros a random subset of activations during training (rescaling the survivors) and is a no-op at eval.
  • BatchNorm uses batch statistics during training, running averages at eval.
  • Variational layers add noise during training only.
  • Stochastic depth skips a residual block with some probability during training.

The nnx idiom used here is a train: bool keyword argument to __call__ that flows down to every stochastic sub-layer. There's no global is_training() flag; explicitness is the whole point. If a layer doesn't receive the flag, it has no way to know whether to behave stochastically.

Compare with PyTorch's model.train() / model.eval() modes (which flip a training flag on every module recursively) or Linen's deterministic=True convention (a field or call argument on nn.Dropout and friends that you thread through by hand). nnx prefers the explicit kwarg pattern: you pass train=True/False on every forward call.

API: train kwarg in __call__

import jax
from flax import nnx


class MyLayer(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.rngs = rngs           # save the container; we'll draw keys later

    def __call__(self, x, train: bool):
        if train:
            key = self.rngs.dropout()                # fresh key on every call
            noise = jax.random.normal(key, x.shape)
            return x + 0.1 * noise                   # training: add small Gaussian noise
        return x                                     # eval: identity

Two API points to notice:

  1. self.rngs = rngs — saving the Rngs container as an attribute lets __call__ pull keys later. (For nnx 0.12.6, this stores the rngs state in the module’s state pytree, so split/merge round-trips preserve the RNG.) This is the standard pattern for layers that need fresh randomness on every forward call (Dropout, stochastic depth, etc.).
  2. rngs.dropout() — the dropout stream is conventional for dropout-like noise. With a single-int Rngs constructor, .dropout() falls back to the default stream and returns a fresh key on each call.

Worked example

import jax.numpy as jnp

x = jnp.arange(4.0)                      # any 1-D input works
rngs = nnx.Rngs(0)
model = MyLayer(rngs=rngs)

out_eval  = model(x, train=False)        # x unchanged
out_train = model(x, train=True)         # x + small noise
print(jnp.abs(out_train - out_eval).max())  # nonzero

Because the Rngs container is stateful, calling model(x, train=True) twice returns different outputs (different keys → different noise). Pure-eval calls are deterministic.

Cast train from float to bool

The harness passes train as 0.0 or 1.0. Convert:

train_b = bool(train >= 0.5)
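
A small sketch of the cast wrapped in a hypothetical helper, to_train_flag, assuming train arrives as a Python float or a concrete (non-traced) JAX scalar. Thresholding at 0.5 rather than testing == 1.0 tolerates values that are not exactly 0.0 or 1.0. bool() needs a concrete value, so the cast belongs outside jax.jit; inside a traced function JAX raises an error when a tracer is converted to bool.

```python
import jax.numpy as jnp


def to_train_flag(train) -> bool:
    """Hypothetical helper: map the harness's 0.0/1.0 float to a Python bool."""
    return bool(train >= 0.5)


print(to_train_flag(0.0))               # False
print(to_train_flag(1.0))               # True
print(to_train_flag(jnp.float32(1.0)))  # True (a concrete JAX scalar works too)
```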

Common pitfalls

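  • Forgetting the train kwarg. Without it, your layer can't tell training from eval. Built-in nnx layers take the same switch explicitly under their own names: nnx.Dropout takes deterministic and nnx.BatchNorm takes use_running_average (both True at eval, i.e. the inverse of train).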
  • Not saving rngs. __call__ runs after __init__; it can’t access local variables from __init__. Save them as attributes.
  • Pulling a key in __init__ and reusing it forever. Then every forward gets the SAME noise — defeating the purpose. Pull a fresh key each forward call.
  • Using rngs.params() for noise. Convention is .dropout() for stochastic forward-time draws; .params() is for one-shot init. They’re different streams unless you constructed Rngs with a single int (in which case they’re the same default stream).

Problem

Write train_eval_forward(seed, x, train):

  1. Cast train_b = bool(train >= 0.5).
  2. Define NoiseLayer(nnx.Module). __init__ saves self.rngs = rngs. __call__(self, x, train) does:
    • if train: key = self.rngs.dropout(); noise = jax.random.normal(key, x.shape); return x + 0.1 * noise
    • else: return x unchanged.
  3. Build nnx.Rngs(int(seed)), instantiate NoiseLayer, return model(x, train=train_b).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • train: float — 0.0 or 1.0.

Output: 1-D array of same shape as x.
