NNX Train/Eval Flag
Why this matters
Many layers behave differently between training and evaluation:
- Dropout zeros activations during training, no-ops at eval.
- BatchNorm uses batch statistics during training, running averages at eval.
- Variational layers add noise during training only.
- Stochastic-depth skips a block with some probability at training.
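As a concrete instance of the first bullet, here is a minimal sketch of a train/eval switch for dropout in plain JAX. The function name dropout_sketch and the inverted-dropout rescaling by 1 / (1 - rate) are illustrative assumptions, not taken from any particular library implementation.

```python
import jax
import jax.numpy as jnp


def dropout_sketch(key, x, rate, train):
    """Hypothetical helper: zero activations while training, no-op at eval."""
    if not train:
        return x  # eval: identity
    # Keep each element with probability (1 - rate).
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    # Rescale the survivors so the expected value matches eval (assumed convention).
    return jnp.where(keep, x / (1.0 - rate), 0.0)


x = jnp.ones(8)
key = jax.random.PRNGKey(0)
y_eval = dropout_sketch(key, x, rate=0.5, train=False)
y_train = dropout_sketch(key, x, rate=0.5, train=True)
```

The same shape of branch (do something random when train is true, pass through otherwise) applies to the other layers in the list.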
A common nnx idiom for this is a train: bool keyword argument
to __call__ that flows down to every stochastic sub-layer. There’s
no global is_training() flag; explicitness is the whole point. If
a layer doesn’t receive the flag, it has no way to know whether to be stochastic.
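Compare with PyTorch’s model.train() / model.eval() modes
(which mutate every module recursively) or Linen’s deterministic flag
(which you pass to nn.Dropout at construction or call time).
Flax’s built-in nnx layers use per-layer flags of the same kind
(deterministic on nnx.Dropout, use_running_average on nnx.BatchNorm).
For custom layers, this tutorial uses the explicit kwarg pattern: you pass
train=True/False on every forward call.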
API: train kwarg in __call__
    import jax
    from flax import nnx

    class MyLayer(nnx.Module):
        def __init__(self, rngs: nnx.Rngs):
            self.rngs = rngs  # save the container; we'll draw keys later

        def __call__(self, x, train: bool):
            if train:
                key = self.rngs.dropout()  # fresh key on every call
                noise = jax.random.normal(key, x.shape)
                return x + 0.1 * noise
            return x  # eval: pass the input through unchanged
Two API points to notice:
- self.rngs = rngs: saving the Rngs container as an attribute lets __call__ pull keys later. (For nnx 0.12.6, this stores the rngs state in the module’s state pytree, so split/merge round-trips preserve the RNG.) This is the standard pattern for layers that need fresh randomness on every forward call (Dropout, stochastic depth, etc.).
- rngs.dropout(): the dropout stream is conventional for dropout-like noise. With a single-int Rngs constructor, .dropout() falls back to the default stream and returns a fresh key on each call.
Worked example
    import jax.numpy as jnp

    x = jnp.arange(4.0)                  # example input; any float array works
    rngs = nnx.Rngs(0)
    model = MyLayer(rngs=rngs)
    out_eval = model(x, train=False)     # x unchanged
    out_train = model(x, train=True)     # x + small noise
    print(jnp.abs(out_train - out_eval).max())  # nonzero
Because the Rngs container is stateful, calling model(x, train=True)
twice returns different outputs (different keys → different noise).
Pure-eval calls are deterministic.
Cast train from float to bool
The harness passes train as 0.0 or 1.0. Convert:
train_b = bool(train >= 0.5)
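A minimal sketch of the cast, using a hypothetical helper name to_train_flag. It assumes train is a concrete value: calling bool() on a traced value inside jax.jit raises an error, so this conversion is meant to run outside any jitted function.

```python
import jax.numpy as jnp


def to_train_flag(train) -> bool:
    """Hypothetical helper: map a 0.0 / 1.0 float onto a Python bool."""
    return bool(train >= 0.5)


flag_eval = to_train_flag(0.0)                 # plain Python float
flag_train = to_train_flag(jnp.float32(1.0))   # JAX scalar works too
```

Comparing against 0.5 rather than testing equality with 1.0 keeps the cast robust to small floating-point error in the input.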
Common pitfalls
- Forgetting the train kwarg. Without it, your layer can’t tell training from eval. Built-in layers take an equivalent flag explicitly: nnx.Dropout accepts deterministic and nnx.BatchNorm accepts use_running_average.
- Not saving rngs. __call__ runs after __init__ and can’t access local variables from __init__. Save them as attributes.
- Pulling a key in __init__ and reusing it forever. Then every forward pass gets the SAME noise, defeating the purpose. Pull a fresh key on each forward call.
- Using rngs.params() for noise. Convention is .dropout() for stochastic forward-time draws; .params() is for one-shot init. They’re different streams unless you constructed Rngs with a single int (in which case they’re the same default stream).
Problem
Write train_eval_forward(seed, x, train):
- Cast train_b = bool(train >= 0.5).
- Define NoiseLayer(nnx.Module). __init__ saves self.rngs = rngs. __call__(self, x, train) does:
  - if train: key = self.rngs.dropout(); noise = jax.random.normal(key, x.shape); return x + 0.1 * noise
  - else: return x unchanged.
- Build nnx.Rngs(int(seed)), instantiate NoiseLayer, return model(x, train=train_b).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- train: float, either 0.0 or 1.0.
Output: 1-D array of same shape as x.