
Train vs Eval Branches via train Flag

Why this matters

Many layers behave differently in training vs evaluation:

  • Dropout: drops random units during training; pass-through at eval.
  • BatchNorm: uses batch statistics during training; uses running averages at eval.
  • Stochastic depth: randomly skips residual branches during training; always present at eval.

Flax does not have a global “training mode” flag the way PyTorch does (model.train() / model.eval()). Instead, you pass the mode explicitly as a kwarg to __call__, typically train: bool or deterministic: bool, and your Module forwards it to the layers that need it: Dropout via deterministic=, BatchNorm via use_running_average=.

This explicitness is intentional: train and eval become two separate apply calls that you can JIT independently (with train marked as a static argument), and no hidden mode state can leak from one to the other.

The train kwarg pattern

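import jax
import jax.numpy as jnp
import flax.linen as nn

class TrainEvalModel(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(features=x.shape[-1])(x)
        # Dropout is active only when train=True (i.e. deterministic=False).
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return x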

Note: nn.Dropout takes deterministic= (the inverse of train). When deterministic=True it’s a no-op; when False it samples a keep-mask and scales the surviving units by 1 / (1 - rate). So:

  • train=True → deterministic=False → dropout active
  • train=False → deterministic=True → dropout pass-through

init signature

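Even if you only plan to evaluate, you still need init to create the parameters. Dropout owns no variables, so init never creates a dropout collection; the RNG that apply(train=True) needs is supplied at apply time, not stored by init. With train=False, Dropout never requests an RNG during init, so the "dropout" stream is optional there. The common pattern is to init with train=False and pass both RNG streams anyway, which is harmless and keeps init working if you later flip the flag: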

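import jax
import jax.numpy as jnp

model = TrainEvalModel()            # the Module defined above
x = jnp.ones((4,))                  # any example input with the right shape

init_key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)

params = model.init(                # returns the variables dict {"params": ...}
    {"params": init_key, "dropout": drop_key},
    x,
    train=False,                    # Dropout has no params, so the flag doesn't change the result
)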

Why train=False for init?

  • The forward path during init only needs to determine parameter shapes.
  • Dropout has no parameters, so it doesn’t care about the train flag for shape inference.
  • Passing train=False is the convention. Passing train=True also works and yields the same parameters, but then Dropout actually runs during init, so the "dropout" RNG stream becomes mandatory. Pick one and stick with it.

apply signatures (train and eval)

# Training step — pass dropout key:
y = model.apply(
    params, x,
    train=True,
    rngs={"dropout": jax.random.PRNGKey(7)},
)

# Eval step — no dropout key needed:
y_eval = model.apply(params, x, train=False)

apply forwards any extra positional and keyword arguments to __call__, so the mode kwarg your __call__ declares (here train) is exactly what you pass to apply.

Worked example: train and eval give different outputs

import jax
import jax.numpy as jnp

model = TrainEvalModel()
key = jax.random.PRNGKey(0)
drop_key = jax.random.PRNGKey(1)
x = jnp.array([1.0, 1.0, 1.0, 1.0])

params = model.init({"params": key, "dropout": drop_key}, x, train=False)

y_train = model.apply(params, x, train=True, rngs={"dropout": drop_key})
y_eval  = model.apply(params, x, train=False)
# y_train and y_eval differ: each unit of y_train is independently zeroed with
# probability 0.5 (about half on average), and survivors are scaled by 2 = 1 / (1 - rate).

Common pitfalls

  • Forgetting to thread train through sub-modules: if you have nested Modules, every layer that needs the flag must receive it. Pass train=train down the call tree.
  • Using deterministic=train instead of deterministic=not train: deterministic is the negation of train. train=True must give deterministic=False, which is what turns dropout on.
  • Init without a dropout key: model.init with only {"params": key} fails if Dropout runs non-deterministically during init (i.e. you init with train=True); with train=False it succeeds. When in doubt, pass both streams.

Problem

Build TrainEvalModel with no Module attributes (just __call__):

  1. Dense(features=x.shape[-1])
  2. Dropout(rate=0.5, deterministic=not train)

The function model_train_eval_forward(seed, x, train):

  • Casts train (float) to bool: bool(train >= 0.5).
  • Builds the model.
  • Inits with {"params": PRNGKey(seed), "dropout": PRNGKey(seed + 1)}, passing train=False to init.
  • Applies with the cast train flag and rngs={"dropout": PRNGKey(seed + 1)}.

Inputs:

  • seed: float (cast to int).
  • x: 1-D JAX array.
  • train: float (0.0 or 1.0).

Output: 1-D array, same length as x.
