Train with Mutable batch_stats
Why this matters
Most “tutorial” Flax models (Dense, LayerNorm, attention) have
only params. They’re easy: gradients flow through every variable.
But the moment you add BatchNorm (or any module with running stats:
e.g. RunningMean for EMA-tracked metrics, online normalization, etc.),
the variable tree splits in two:
- "params": trainable. Gradients flow through them and the optimizer updates them.
- "batch_stats": mutable but not trainable. They change between calls, but gradients are blocked and the optimizer never touches them.
Training a model with this split is the canonical Flax pattern for real CNNs (ResNet) and for other vision models that use BatchNorm. Get this right and the same training step works for any architecture that carries non-trainable state alongside its params.
The split
model.init(...) returns a dict {"params": ..., "batch_stats": ...}.
Pull them apart immediately:
```python
variables = model.init(rng, x, train=False)
params = variables["params"]
batch_stats = variables["batch_stats"]
```
The optimizer only sees params:
```python
opt_state = tx.init(params)  # NOT tx.init(variables)
```
The training step (loss with aux)
loss_fn must return BOTH the scalar loss (for the gradient) AND the new
batch_stats (so we can pass them forward). jax.value_and_grad has a flag for exactly this, has_aux:
```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch_stats, x, y):
    out, updated = model.apply(
        {"params": params, "batch_stats": batch_stats},
        x, train=True, mutable=["batch_stats"]
    )
    preds = out.reshape(-1)
    loss = jnp.mean((preds - y) ** 2)
    return loss, updated["batch_stats"]  # (scalar, aux)

# has_aux=True tells value_and_grad: "the second return is auxiliary;
# don't differentiate through it; pass it back to me."
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, new_batch_stats), grads = grad_fn(params, batch_stats, x, y)
```
grads has the SAME pytree structure as params. By default value_and_grad
differentiates only with respect to argument 0 (argnums=0). batch_stats is argument 1, so
it doesn’t appear in the gradient. Good.
Threading the new state
After the step, both params and batch_stats get rebound:
```python
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)  # via optimizer
batch_stats = new_batch_stats                  # direct overwrite
```
params get the optimizer’s idea of an update. batch_stats get the
EMA-updated values that came out of the forward pass.
Eval mode
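For eval / metric reporting, use train=False. With the usual wiring, BatchNorm(use_running_average=not train),
this sets use_running_average=True, so BatchNorm normalizes with the stored running averages
rather than the current batch's statistics. Pass the SAME variables dict. In this mode BatchNorm
reads from batch_stats but doesn’t write, so no mutable argument is needed and apply returns just the output:
```python
out = model.apply(
    {"params": params, "batch_stats": batch_stats},
    x, train=False
)
```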
Common pitfalls
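- Initializing the optimizer on variables: optax builds state for the batch_stats leaves too. With a stateful optimizer (momentum, Adam) that state no longer matches the structure of grads, and tx.update fails. Use tx.init(params) only.
- Forgetting mutable=["batch_stats"] on the train-mode apply: error. BatchNorm tries to write its running averages, but the collection isn't declared mutable.
- Treating batch_stats like params and feeding them to the optimizer. In train mode the loss depends on the batch statistics, not on the stored running averages, so their gradients are zero. At best the optimizer leaves them unchanged; at worst (e.g. with weight decay) it drifts them away from the true statistics. Meanwhile the EMA values computed by the forward pass are never stored.
- Forgetting has_aux=True: value_and_grad expects a scalar output and raises a TypeError when loss_fn returns the (loss, new_batch_stats) tuple.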
Problem
The model is a small Dense → BatchNorm → ReLU → Dense(1) MLP.
Implement one training step that:
- Inits params and batch_stats from seed and x_batch.
- Defines loss_fn returning (loss, new_batch_stats).
- Uses jax.value_and_grad(loss_fn, has_aux=True) to get ((loss, new_batch_stats), grads).
- Applies grads to params via optax.sgd(lr) and overwrites batch_stats.
- Computes the post-step MSE in eval mode (train=False) and returns it as a 1-D array of shape (1,).
Inputs:
- seed: float (cast to int).
- x_batch: 2-D, shape (N, D).
- y_batch: 1-D, shape (N,).
- lr: float.
Output: 1-D array of shape (1,): [final_eval_loss].