NNX Implement BatchNorm
Why this matters
BatchNorm is the layer where nnx’s design philosophy pays off the
most. It has FOUR pieces of state at once — two trainable parameters
(gamma, beta) and two non-trainable but mutable running statistics
(running_mean, running_var) that update during training.
In Linen this required a separate "batch_stats" variable collection,
is_initializing() guards inside __call__, declaring mutable=...
at apply time, and a return tuple of (out, mutated_state) you had to
keep threading through. In nnx, you just write:
    self.running_mean.value = momentum * self.running_mean.value + ...
No collections, no mutable flags, no apply-time tuples. The module
is the state container, and writing to a Variable.value mutates it
in place. Optax-managed parameters and EMA-updated statistics coexist
on the same object, distinguished only by the wrapper class.
API
Two nnx.Params for trainables:
- gamma: per-feature gain, init to ones, shape (D,).
- beta: per-feature offset, init to zeros, shape (D,).
Two nnx.Variables for running statistics (updated in __call__, not by the optimizer):
- running_mean: init to zeros, shape (D,).
- running_var: init to ones, shape (D,). (Var starts at 1, not 0: with a variance of zero, the very first eval pass would divide by sqrt(eps), about 0.003 for eps=1e-5, and blow up the output.)
Plus two static attributes for hyperparameters: momentum, eps.
Train vs eval
Driven by the use_running_average flag:
- use_running_average=False (training):
  - Compute batch statistics: mu = mean(x, axis=0), var = mean((x - mu)**2, axis=0).
  - Update running stats with EMA: running_mean.value = momentum * running_mean.value + (1 - momentum) * mu (and likewise for var).
  - Normalize the input using the BATCH stats.
- use_running_average=True (eval):
  - Read the running stats; no update.
  - Normalize using running stats, with no dependence on the current batch.
Both branches finish with gamma * x_hat + beta.
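The two branches can also be written as a pure function over plain arrays, which makes the arithmetic easy to check by hand. This is only an illustrative sketch: batchnorm_step is a hypothetical helper, not part of nnx, and it returns the updated statistics instead of mutating anything.

```python
import jax.numpy as jnp


def batchnorm_step(x, gamma, beta, running_mean, running_var,
                   use_running_average, momentum=0.9, eps=1e-5):
    """Hypothetical functional BatchNorm step: returns (out, new_mean, new_var)."""
    if use_running_average:
        # Eval: read the running stats, leave them untouched.
        mu, var = running_mean, running_var
    else:
        # Train: per-feature batch statistics over axis 0.
        mu = jnp.mean(x, axis=0)
        var = jnp.mean((x - mu) ** 2, axis=0)
        # EMA update of the running statistics.
        running_mean = momentum * running_mean + (1.0 - momentum) * mu
        running_var = momentum * running_var + (1.0 - momentum) * var
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta, running_mean, running_var


x = jnp.array([[1.0, 10.0], [3.0, 30.0]])  # (N=2, D=2)
d = x.shape[-1]
gamma, beta = jnp.ones((d,)), jnp.zeros((d,))
rm, rv = jnp.zeros((d,)), jnp.ones((d,))

train_out, rm_new, rv_new = batchnorm_step(x, gamma, beta, rm, rv, False)
eval_out, rm_same, rv_same = batchnorm_step(x, gamma, beta, rm, rv, True)
```

Here the batch mean is [2, 20] and the batch variance is [1, 100], so one training step moves the running mean from zeros to 0.1 * [2, 20]. The eval call ignores the batch and normalizes with the initial zeros and ones.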
Why axis=0?
BatchNorm normalizes ACROSS the batch, per feature. For input (N, D),
axis=0 gives shape (D,) per-feature statistics — one mean and one
variance per channel. Compare with LayerNorm (axis=-1) which gives
per-sample statistics.
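A quick shape check makes the axis choice concrete. The array below is an arbitrary example with N=4 samples and D=3 features.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(4, 3)  # rows are samples, columns are features

# BatchNorm-style: reduce over the batch axis, one statistic per feature.
per_feature_mean = jnp.mean(x, axis=0)

# LayerNorm-style: reduce over the feature axis, one statistic per sample.
per_sample_mean = jnp.mean(x, axis=-1)

feature_shape = per_feature_mean.shape
sample_shape = per_sample_mean.shape
```

With axis=0 the result has shape (3,), matching gamma and beta; with axis=-1 it has shape (4,), which would not broadcast against per-feature parameters in the same way.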
Worked sketch
    import jax.numpy as jnp
    from flax import nnx

    class MyBatchNorm(nnx.Module):
        def __init__(self, d, momentum, eps, rngs):
            # rngs is unused here: all four initializers are constants.
            self.gamma = nnx.Param(jnp.ones((d,)))
            self.beta = nnx.Param(jnp.zeros((d,)))
            self.running_mean = nnx.Variable(jnp.zeros((d,)))
            self.running_var = nnx.Variable(jnp.ones((d,)))
            self.momentum = momentum
            self.eps = eps

        def __call__(self, x, use_running_average):
            if use_running_average:
                mu = self.running_mean.value
                var = self.running_var.value
            else:
                mu = jnp.mean(x, axis=0)
                var = jnp.mean((x - mu) ** 2, axis=0)
                # Mutate running stats IN PLACE.
                self.running_mean.value = (
                    self.momentum * self.running_mean.value
                    + (1.0 - self.momentum) * mu
                )
                self.running_var.value = (
                    self.momentum * self.running_var.value
                    + (1.0 - self.momentum) * var
                )
            x_hat = (x - mu) / jnp.sqrt(var + self.eps)
            return self.gamma * x_hat + self.beta
The mutation lines are the headline. In Linen the equivalent is:
    # Linen — for contrast.
    if not self.is_initializing():
        running_mean.value = momentum * running_mean.value + (1 - momentum) * mu

    # ... and at apply-site:
    out, updates = model.apply({"params": p, "batch_stats": bs}, x,
                               use_running_average=False,
                               mutable=["batch_stats"])
    new_bs = updates["batch_stats"]  # carry forward
Three points of friction: the is_initializing() guard, the
mutable=["batch_stats"] declaration, and the (out, updates) return
tuple. nnx makes them all go away.
What “in place” means under JAX
JAX arrays are immutable, so self.running_mean.value = new_array
isn’t really mutating the array — it’s rebinding the wrapper’s
.value attribute to a new array. The nnx.Variable wrapper provides
the appearance of mutation while keeping JAX semantics underneath.
nnx.split reads whatever array each Variable currently holds, so the state
it returns contains the new value.
Common pitfalls
- running_var initialized to zeros. The first eval pass would divide by sqrt(0 + eps) = sqrt(eps), blowing up the output. Init to ones.
- Updating running stats in eval mode. The if use_running_average branch must NOT mutate; only the train branch updates.
- Using running stats in train mode. Train mode normalizes by batch stats. The running stats are only for eval.
- use_running_average arriving as float. The harness passes it as 0.0 / 1.0; cast to bool with bool(flag >= 0.5) (or int(flag)).
- Stats over the wrong axis. axis=-1 is LayerNorm, not BatchNorm. For 2-D (N, D) input, BatchNorm uses axis=0.
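Two of these pitfalls are easy to reproduce numerically. The sketch below compares an eval-style normalization with running_var at zeros versus ones, and shows the float-to-bool cast; to_bool is a hypothetical helper name and the input values are arbitrary.

```python
import jax.numpy as jnp

eps = 1e-5
x = jnp.array([[1.0, -2.0]])
running_mean = jnp.zeros((2,))

# Pitfall: running_var initialised to zeros, so the divisor is sqrt(eps).
bad = (x - running_mean) / jnp.sqrt(jnp.zeros((2,)) + eps)
# Recommended: running_var initialised to ones.
good = (x - running_mean) / jnp.sqrt(jnp.ones((2,)) + eps)

# How much larger the zero-init output is than the ones-init output.
scale = float(bad[0, 0] / good[0, 0])


def to_bool(flag):
    """Hypothetical helper: map the harness's 0.0 / 1.0 float to a Python bool."""
    return bool(flag >= 0.5)
```

For eps=1e-5 the zero-initialized variance inflates the output by roughly a factor of 316, since 1 / sqrt(1e-5) is about 316.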
Problem
Write batchnorm_forward(seed, x, use_running_average):
- Define MyBatchNorm(nnx.Module) with the four state attributes above (gamma, beta as nnx.Param; running_mean, running_var as nnx.Variable), plus momentum=0.9 and eps=1e-5 as plain attributes.
- __call__(x, use_running_average):
  - If True: read running stats, normalize, no mutation.
  - Else: compute batch stats, update running stats via EMA, normalize.
  - Return gamma * x_hat + beta.
- Cast use_running_average from float to bool with bool(flag >= 0.5).
- Build with nnx.Rngs(int(seed)), instantiate MyBatchNorm(d=x.shape[-1], ...), return model(x, use_running_average=use_run).reshape(-1).
Inputs:
- seed: int (passed as float).
- x: 2-D (N, D).
- use_running_average: float (0.0 or 1.0).
Output: 1-D flattened.