NNX Train with BatchNorm

Why this matters

Training a model with BatchNorm is the original headache that motivated nnx’s design. In Linen, BN’s running statistics live in a separate "batch_stats" collection, you must declare mutable=["batch_stats"] at apply time, the apply returns a (out, updates) tuple, and you have to thread updates["batch_stats"] forward through every step alongside params. Get any link wrong and your running stats silently never update.

In nnx, the running stats live as nnx.BatchStat Variables on the nnx.BatchNorm module. Calling bn(x, use_running_average=False) mutates them in place. The optimizer doesn’t see them (because wrt=nnx.Param filters them out). The training loop is identical to a model without BatchNorm — no extra state to thread.

The model

import jax
from flax import nnx

class BNModel(nnx.Module):
    def __init__(self, d_in, hidden, d_out, rngs):
        self.linear1 = nnx.Linear(d_in, hidden, rngs=rngs)
        self.bn = nnx.BatchNorm(hidden, rngs=rngs)
        self.linear2 = nnx.Linear(hidden, d_out, rngs=rngs)

    def __call__(self, x, use_running_average):
        x = self.linear1(x)
        x = self.bn(x, use_running_average=use_running_average)
        x = jax.nn.relu(x)
        return self.linear2(x)

The training loop

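# Assumes `model` is a BNModel and `bs`, `lr`, `num_steps` are already defined.
num_batches = x_batches_flat.shape[0] // bs
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)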

def loss_fn(model, x, y):
    pred = model(x, use_running_average=False)   # train mode
    return jnp.mean((pred - y) ** 2)

for step in range(num_steps):
    i = step % num_batches
    xb = x_batches_flat[i*bs:(i+1)*bs]
    yb = y_batches_flat[i*bs:(i+1)*bs]
    _, grads = nnx.value_and_grad(loss_fn)(model, xb, yb)
    optimizer.update(model, grads)

The nnx.Param filter is doing real work in two places: nnx.value_and_grad differentiates only with respect to nnx.Param variables by default, and wrt=nnx.Param restricts the optimizer to updating those same variables. The BN running stats are nnx.BatchStat Variables, so they are neither differentiated nor updated by optax, but they ARE mutated by __call__ itself.

The eval forward

pred_eval = model(x_eval, use_running_average=True)

With use_running_average=True, the BN normalizes with its stored running mean and variance (the mean and var BatchStat variables on nnx.BatchNorm) instead of computing batch statistics. The result is deterministic and does not depend on which other rows happen to be in the eval batch.

Compared to Linen

The Linen equivalent looks like:

# Linen — for contrast.
out, updates = state.apply_fn(
    {"params": state.params, "batch_stats": state.batch_stats},
    x, use_running_average=False,
    mutable=["batch_stats"],
)
state = state.replace(batch_stats=updates["batch_stats"])

Three things to thread (params, batch_stats, mutable=...), one return tuple to unpack, and a state.replace to bolt the new stats back on. nnx makes all of that go away.

Common pitfalls

  • Forgetting use_running_average=False during training. Then the BN normalizes with whatever running stats it already holds (the initial zero mean and unit variance if no train-mode pass has run yet) and never accumulates fresh batch statistics.
  • Setting wrt=nnx.Variable instead of nnx.Param on the optimizer. Then optax tries to “step” the running stats too, which is wrong. Stick to wrt=nnx.Param.
  • Calling the model in train mode on the WHOLE flat batch, which mutates the running stats only once. Each train-mode call is one EMA tick, so a single call on the entire dataset moves the running stats by just one small step. Slice into mini-batches so the stats receive one tick per training step.
  • Computing eval with use_running_average=False. Then eval depends on the eval batch’s stats, which is the OPPOSITE of what you want.
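The "one EMA tick" point is easy to see with plain arithmetic. The sketch below does not call Flax at all: it assumes the usual update running = momentum * running + (1 - momentum) * batch_mean with momentum 0.99, and ema_after is a hypothetical helper written for this illustration.

```python
def ema_after(num_ticks, batch_mean, momentum=0.99, start=0.0):
    """Running mean after `num_ticks` EMA updates toward a fixed batch mean."""
    running = start
    for _ in range(num_ticks):
        running = momentum * running + (1.0 - momentum) * batch_mean
    return running

# One train-mode call on the whole dataset: a single tick from the initial 0.
one_tick = ema_after(1, batch_mean=5.0)

# Many mini-batch steps: the running mean has time to approach the true mean.
many_ticks = ema_after(200, batch_mean=5.0)
```

Under these assumptions a single tick moves the running mean only 1% of the way toward the batch mean, so eval after one whole-dataset call would still normalize with nearly initial statistics.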

Problem

Implement train_with_bn(seed, x_batches_flat, y_batches_flat, batch_size, lr, num_steps):

  1. Define BNModel(nnx.Module) with linear1: nnx.Linear(D_in, 8), bn: nnx.BatchNorm(8), linear2: nnx.Linear(8, D_out). __call__(x, use_running_average) is linear1 -> bn -> relu -> linear2.
  2. Build with nnx.Rngs(int(seed)) and wrap in nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  3. Slice x_batches_flat into chunks of batch_size. Loop int(num_steps) times; each step picks the next batch (cycle with step % num_batches). Compute grads with use_running_average=False, then optimizer.update(model, grads).
  4. After training, call model(x_batches_flat, use_running_average=True) on the FULL flat batch. Return jnp.array([float(jnp.mean(pred))]) as 1-D (1,).
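The batch selection in step 3 can be sketched on its own. batch_for_step is a hypothetical helper name, and the toy arrays are made up; it only illustrates the slicing and the step % num_batches cycling, not the full train_with_bn.

```python
import jax.numpy as jnp

def batch_for_step(x, y, batch_size, step):
    """Return the mini-batch for `step`, cycling through the dataset."""
    bs = int(batch_size)              # inputs arrive as floats
    num_batches = x.shape[0] // bs    # batch_size divides N exactly
    i = int(step) % num_batches
    return x[i * bs:(i + 1) * bs], y[i * bs:(i + 1) * bs]

# Toy data: N = 6 rows, so batch_size = 2 gives 3 batches.
x = jnp.arange(12.0).reshape(6, 2)
y = jnp.arange(6.0).reshape(6, 1)

xb0, yb0 = batch_for_step(x, y, 2.0, 0)  # rows 0-1
xb3, yb3 = batch_for_step(x, y, 2.0, 3)  # wraps around to rows 0-1
xb4, yb4 = batch_for_step(x, y, 2.0, 4)  # rows 2-3
```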

Inputs:

  • seed: float (cast to int).
  • x_batches_flat: 2-D (N, D_in).
  • y_batches_flat: 2-D (N, D_out).
  • batch_size: float (cast to int). Divides N exactly.
  • lr: float.
  • num_steps: float (cast to int).

Output: 1-D array of shape (1,): [mean_of_eval_pred].

Hints

flax nnx batchnorm training mutable-state
