Train vs Eval Branches via train Flag
Why this matters
Many layers behave differently in training vs evaluation:
- Dropout: drops random units during training; pass-through at eval.
- BatchNorm: uses batch statistics during training; uses running averages at eval.
- Stochastic depth: randomly skips residual branches during training; always present at eval.
Flax does not have a global “training mode” flag the way PyTorch does
(model.train() / model.eval()). Instead, you pass the mode in
explicitly as a kwarg to __call__ — typically train: bool or
deterministic: bool. The Module passes that value down to layers like
Dropout that need it.
This explicitness is intentional: you can compile separate apply calls for
train and eval, JIT them independently, and not worry about hidden state
sneaking between them.
The train kwarg pattern
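import jax
import jax.numpy as jnp
import flax.linen as nn

class TrainEvalModel(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(features=x.shape[-1])(x)
        # Dropout takes the inverse of train: deterministic=True disables it.
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return x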
Note: nn.Dropout takes deterministic= (the inverse of train). When
deterministic=True it’s a no-op; when False it samples a mask. So:
- train=True → deterministic=False → dropout active
- train=False → deterministic=True → dropout pass-through
init signature
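Even if you only ever evaluate, you have to init the model first. Dropout has
no parameters, and "dropout" is the name of an RNG stream rather than a
variable collection, so init does not register anything for it. With
train=False the dropout key is never consumed during init; a dropout key is
only needed by calls in which dropout is active (train=True). The standard
pattern is still to init with train=False and provide BOTH RNG streams anyway,
which is harmless and keeps init working if the flag is ever flipped: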
init_key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)
params = model.init(
    {"params": init_key, "dropout": drop_key},
    x,
    train=False,  # init mode doesn't matter for dropout shape
)
Why train=False for init?
- The forward path during init only needs to determine parameter shapes.
- Dropout has no parameters, so it doesn’t care about the train flag for shape inference.
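- Passing train=False is the convention; some teams pass train=True instead. The initialized parameters come out the same either way, although with train=True init actually samples a dropout mask and so uses the dropout key. Pick one and stick with it.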
apply signatures (train and eval)
# Training step — pass dropout key:
y = model.apply(
    params, x,
    train=True,
    rngs={"dropout": jax.random.PRNGKey(7)},
)
# Eval step — no dropout key needed:
y_eval = model.apply(params, x, train=False)
The kwarg signature flows through — if your __call__ takes train, your
apply passes train.
Worked example: train and eval give different outputs
model = TrainEvalModel()
key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)
x = jnp.array([1.0, 1.0, 1.0, 1.0])
params = model.init({"params": key, "dropout": drop_key}, x, train=False)
y_train = model.apply(params, x, train=True, rngs={"dropout": drop_key})
y_eval = model.apply(params, x, train=False)
# y_train and y_eval generally differ: in y_train each unit is zeroed independently
# with probability 0.5, and the surviving units are scaled by 1 / (1 - 0.5) = 2.
Common pitfalls
- Forgetting to thread train through sub-modules: if you have nested Modules, every layer that needs the flag must receive it. Pass train=train down the call tree.
- Using deterministic=train instead of not train: deterministic is the negation. train=True means dropout is active, which means deterministic=False.
- Init without a dropout key: calling model.init with only {"params": key} can fail with a missing-RNG error, but only when the model contains Dropout that is active during init (train=True). Always pass both streams when in doubt.
Problem
Build TrainEvalModel with no Module attributes (just __call__):
- Dense(features=x.shape[-1])
- Dropout(rate=0.5, deterministic=not train)
The function model_train_eval_forward(seed, x, train):
- Casts train (float) to bool: bool(train >= 0.5).
- Builds the model.
- Inits with {"params": PRNGKey(seed), "dropout": PRNGKey(seed + 1)}, passing train=False to init.
- Applies with the cast train flag and rngs={"dropout": PRNGKey(seed + 1)}.
Inputs:
- seed: float (cast to int).
- x: 1-D JAX array.
- train: float (0.0 or 1.0).
Output: 1-D array, same length as x.