NNX Train with BatchNorm
Why this matters
Training a model with BatchNorm is the original headache that
motivated nnx’s design. In Linen, BN’s running statistics live in a
separate "batch_stats" collection, you must declare
mutable=["batch_stats"] at apply time, the apply returns a
(out, updates) tuple, and you have to thread updates["batch_stats"]
forward through every step alongside params. Forget to write those
updates back and your running stats silently stay frozen.
In nnx, the running stats live as nnx.BatchStat Variables on the
nnx.BatchNorm module. Calling bn(x, use_running_average=False)
mutates them in place. The optimizer doesn’t see them (because
wrt=nnx.Param filters them out). The training loop is identical
to a model without BatchNorm — no extra state to thread.
The model
```python
import jax
from flax import nnx

class BNModel(nnx.Module):
    def __init__(self, d_in, hidden, d_out, rngs):
        self.linear1 = nnx.Linear(d_in, hidden, rngs=rngs)
        self.bn = nnx.BatchNorm(hidden, rngs=rngs)   # holds BatchStat running stats
        self.linear2 = nnx.Linear(hidden, d_out, rngs=rngs)

    def __call__(self, x, use_running_average):
        x = self.linear1(x)
        x = self.bn(x, use_running_average=use_running_average)
        x = jax.nn.relu(x)
        return self.linear2(x)
```
The training loop
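```python
import jax.numpy as jnp
import optax
from flax import nnx

# Assumes model, lr, num_steps, bs, x_batches_flat, y_batches_flat exist.
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)
num_batches = x_batches_flat.shape[0] // bs

def loss_fn(model, x, y):
    pred = model(x, use_running_average=False)  # train mode
    return jnp.mean((pred - y) ** 2)

for step in range(num_steps):
    i = step % num_batches  # cycle through the mini-batches
    xb = x_batches_flat[i*bs:(i+1)*bs]
    yb = y_batches_flat[i*bs:(i+1)*bs]
    _, grads = nnx.value_and_grad(loss_fn)(model, xb, yb)
    optimizer.update(model, grads)  # steps only nnx.Param leaves
```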
The nnx.Param filter is doing real work here:
nnx.value_and_grad differentiates only nnx.Param variables by default,
and wrt=nnx.Param restricts the optimizer to that same set. The BN
running stats are nnx.BatchStat variables, so optax neither
differentiates nor updates them. They ARE, however, mutated by
__call__ itself, and nnx propagates that mutation out of the
value_and_grad call back onto the model.
The eval forward
```python
pred_eval = model(x_eval, use_running_average=True)  # eval mode
```
With use_running_average=True, BN normalizes with its stored running
mean and variance (the mean and var BatchStat variables on
nnx.BatchNorm) instead of computing batch statistics. The result is
deterministic and independent of the eval batch’s distribution.
Compared to Linen
The Linen equivalent looks like:
```python
# Linen, for contrast. Assumes state is a flax TrainState subclass
# that carries an extra batch_stats field.
out, updates = state.apply_fn(
    {"params": state.params, "batch_stats": state.batch_stats},
    x, use_running_average=False,
    mutable=["batch_stats"],
)
state = state.replace(batch_stats=updates["batch_stats"])
```
Three things to thread (params, batch_stats, mutable=...), one
return tuple to unpack, and a state.replace to bolt the new stats
back on. nnx makes all of that go away.
Common pitfalls
- Forgetting use_running_average=False during training. Then BN reads stale running stats (at step 0, just the initial zero mean and unit variance) and never accumulates fresh batch statistics.
- Setting wrt=nnx.Variable instead of nnx.Param on the optimizer. Then the optimizer treats the running stats as trainable and tries to “step” them too, which is wrong: they are updated by the forward pass, not by gradients. Stick to wrt=nnx.Param.
- Calling the model on the WHOLE flat batch, so the running stats are mutated only once. Each train-mode call is one EMA tick; with the default momentum of 0.99, a single call on the entire dataset moves the stats only 1% of the way toward that batch’s statistics. We slice into mini-batches.
- Computing eval with use_running_average=False. Then eval depends on the eval batch’s stats, which is the OPPOSITE of what you want, and the call also nudges the running stats toward the eval data.
Problem
Implement train_with_bn(seed, x_batches_flat, y_batches_flat, batch_size, lr, num_steps):
- Define BNModel(nnx.Module) with linear1: nnx.Linear(D_in, 8), bn: nnx.BatchNorm(8), linear2: nnx.Linear(8, D_out). __call__(x, use_running_average) is linear1 -> bn -> relu -> linear2.
- Build with nnx.Rngs(int(seed)) and wrap in nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Slice x_batches_flat and y_batches_flat into chunks of batch_size. Loop int(num_steps) times; each step picks the next batch (cycle with step % num_batches). Compute grads with use_running_average=False, then optimizer.update(model, grads).
- After training, call model(x_batches_flat, use_running_average=True) on the FULL flat batch. Return jnp.array([float(jnp.mean(pred))]) as a 1-D array of shape (1,).
Inputs:
- seed: float (cast to int).
- x_batches_flat: 2-D (N, D_in).
- y_batches_flat: 2-D (N, D_out).
- batch_size: float (cast to int). Divides N exactly.
- lr: float.
- num_steps: float (cast to int).
Output: 1-D (1,) — [mean_of_eval_pred].
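The slicing arithmetic in the spec can be checked on its own. batch_slices is a hypothetical helper, not part of the required signature. It only shows which rows each step reads once the float inputs are cast to int.

```python
def batch_slices(n, batch_size, num_steps):
    """Hypothetical helper: return the (start, stop) row range used at each
    training step. batch_size and num_steps arrive as floats in this
    exercise, so they are cast; assumes batch_size divides n exactly."""
    bs = int(batch_size)
    num_batches = n // bs
    slices = []
    for step in range(int(num_steps)):
        i = step % num_batches  # wrap around after the last batch
        slices.append((i * bs, (i + 1) * bs))
    return slices


print(batch_slices(12, 4.0, 5.0))
# [(0, 4), (4, 8), (8, 12), (0, 4), (4, 8)]
```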