
Train with Mutable batch_stats

Why this matters

Most “tutorial” Flax models — Dense, LayerNorm, attention — have only params. They’re easy: gradients flow through every variable.

But the moment you add BatchNorm (or any module that keeps running statistics, e.g. an EMA tracker for metrics or an online normalization layer), the variable tree splits in two:

  • "params": trainable. Gradients flow through them, and the optimizer updates them.
  • "batch_stats": mutable but not trainable. BatchNorm rewrites them on every training-mode forward pass, but no gradients are taken with respect to them and the optimizer never touches them.

Training a model with this split is the canonical Flax pattern for real CNNs such as ResNet, and for any other architecture that normalizes with BatchNorm. Once it is right, the same training loop carries over to most architectures you will meet in the wild.

The split

model.init(...) returns a dict {"params": ..., "batch_stats": ...}. Pull them apart immediately:

variables = model.init(rng, x, train=False)
params       = variables["params"]
batch_stats  = variables["batch_stats"]

The optimizer only sees params:

opt_state = tx.init(params)        # NOT tx.init(variables)

The training step (loss with aux)

loss_fn must return BOTH the scalar loss (to differentiate) AND the new batch_stats (so we can thread them into the next step). jax.value_and_grad supports exactly this through its has_aux flag:

def loss_fn(params, batch_stats, x, y):
    out, updated = model.apply(
        {"params": params, "batch_stats": batch_stats},
        x, train=True, mutable=["batch_stats"]
    )
    preds = out.reshape(-1)
    loss = jnp.mean((preds - y) ** 2)
    return loss, updated["batch_stats"]   # (scalar, aux)

# has_aux=True tells value_and_grad: "the second return is auxiliary;
# don't differentiate through it; pass it back to me."
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, new_batch_stats), grads = grad_fn(params, batch_stats, x, y)

grads has the SAME pytree structure as params. batch_stats is not an input we differentiate against (value_and_grad differentiates with respect to argument 0, params, by default; batch_stats is argument 1), so it does not appear in the gradient. Good.

Threading the new state

After the step, both params and batch_stats get rebound:

updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)     # via optimizer
batch_stats = new_batch_stats                     # direct overwrite

params receive the optimizer's update, computed from grads. batch_stats are simply replaced by the EMA-updated running statistics that came out of the training-mode forward pass.

Eval mode

For eval / metric reporting, use train=False (with the usual wiring BatchNorm(use_running_average=not train), this makes BatchNorm normalize with the stored running averages). Pass the SAME variables dict. BatchNorm only reads batch_stats in this mode and never writes them, so no collection needs to be mutable, and apply returns the output alone:

out = model.apply(
    {"params": params, "batch_stats": batch_stats},
    x, train=False
)

Common pitfalls

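  • Initializing the optimizer on variables: optax would build state for the whole tree, including batch_stats, which are statistics, not weights, and never receive gradients. Call tx.init(params) only.
  • Forgetting mutable=["batch_stats"] on the train-mode apply: Flax raises an error, because BatchNorm tries to write its running statistics to a collection that isn't declared mutable.
  • Treating batch_stats like params (differentiating with respect to them or passing them through the optimizer). In training mode the loss does not depend on the running averages, so their gradient is zero: the optimizer either leaves them unchanged or, with weight decay, pulls them toward zero, instead of tracking the data. Nothing raises an error.
  • Forgetting has_aux=True: value_and_grad requires a scalar output, so it raises a TypeError when loss_fn returns the (loss, batch_stats) tuple.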

Problem

The model is a small Dense → BatchNorm → ReLU → Dense(1) MLP. Implement one training step that:

  1. Inits params and batch_stats from seed and x_batch.
  2. Defines loss_fn returning (loss, new_batch_stats).
  3. Uses jax.value_and_grad(loss_fn, has_aux=True) to get ((loss, new_batch_stats), grads).
  4. Applies grads to params via optax.sgd(lr) and overwrites batch_stats.
  5. Computes the post-step MSE in eval mode (train=False) and returns it as 1-D (1,).

Inputs:

  • seed: float (cast to int).
  • x_batch: 2-D (N, D).
  • y_batch: 1-D (N,).
  • lr: float.

Output: 1-D array of shape (1,), i.e. [final_eval_loss].

Hints

flax batchnorm mutable-state training
