NNX Implement BatchNorm

Why this matters

BatchNorm is the layer where nnx’s design philosophy pays off the most. It has FOUR pieces of state at once — two trainable parameters (gamma, beta) and two non-trainable but mutable running statistics (running_mean, running_var) that update during training.

In Linen this required a separate "batch_stats" variable collection, is_initializing() guards inside __call__, declaring mutable=... at apply time, and a return tuple of (out, mutated_state) you had to keep threading through. In nnx, you just write:

self.running_mean.value = momentum * self.running_mean.value + ...

No collections, no mutable flags, no apply-time tuples. The module is the state container, and writing to a Variable.value mutates it in place. Optax-managed parameters and EMA-updated statistics coexist on the same object, distinguished only by the wrapper class.

API

Two nnx.Params for trainables:

  • gamma: per-feature gain, init to ones, shape (D,).
  • beta: per-feature offset, init to zeros, shape (D,).

Two nnx.Variables for running statistics (updated in __call__, not by the optimizer):

  • running_mean: init to zeros, shape (D,).
  • running_var: init to ones, shape (D,). (Var starts at 1, not 0: with zero variance the first eval pass would divide by sqrt(eps), scaling outputs by roughly 316x for eps=1e-5.)

Plus two static attributes for hyperparameters: momentum, eps.

Train vs eval

Driven by the use_running_average flag:

  • use_running_average=False (training):

    1. Compute batch statistics: mu = mean(x, axis=0), var = mean((x - mu)**2, axis=0).
    2. Update running stats with EMA: running_mean.value = momentum * running_mean.value + (1 - momentum) * mu (and likewise for var).
    3. Normalize the input using the BATCH stats.
  • use_running_average=True (eval):

    1. Read the running stats; no update.
    2. Normalize using running stats — no dependence on the current batch.

Both branches finish with gamma * x_hat + beta.
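
The training-branch arithmetic can be checked by hand on a tiny batch. The sketch below uses NumPy as a stand-in for jax.numpy (the calls used here have the same names in both) and assumes momentum=0.9, eps=1e-5, and a made-up 2x2 input.

```python
import numpy as np

momentum, eps = 0.9, 1e-5
x = np.array([[1.0, 2.0],
              [3.0, 6.0]])          # N=2 samples, D=2 features

running_mean = np.zeros(2)          # initial running stats
running_var = np.ones(2)

# Step 1: batch statistics, per feature.
mu = x.mean(axis=0)
var = ((x - mu) ** 2).mean(axis=0)

# Step 2: EMA update of the running statistics.
running_mean = momentum * running_mean + (1 - momentum) * mu
running_var = momentum * running_var + (1 - momentum) * var

# Step 3 (train): normalize with the BATCH stats.
x_hat_train = (x - mu) / np.sqrt(var + eps)

# Eval would instead normalize with the running stats.
x_hat_eval = (x - running_mean) / np.sqrt(running_var + eps)

print(mu, var)
print(running_mean, running_var)
```

After one update the running stats have moved only 10% of the way toward the batch stats, so the train and eval outputs for the same input differ noticeably early in training.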

Why axis=0?

BatchNorm normalizes ACROSS the batch, per feature. For input (N, D), axis=0 gives shape (D,) per-feature statistics — one mean and one variance per channel. Compare with LayerNorm (axis=-1) which gives per-sample statistics.
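
A quick shape check makes the difference concrete. This is a small NumPy sketch (jax.numpy behaves the same way for these calls) on an arbitrary 4x3 input.

```python
import numpy as np

x = np.arange(12, dtype=np.float64).reshape(4, 3)   # N=4 samples, D=3 features

# BatchNorm-style: reduce over the batch axis, one value per feature.
batch_mean = x.mean(axis=0)

# LayerNorm-style: reduce over the feature axis, one value per sample.
layer_mean = x.mean(axis=-1)

print(batch_mean.shape, layer_mean.shape)
```

The BatchNorm statistics have shape (D,), so they broadcast against gamma and beta; the LayerNorm statistics have shape (N,).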

Worked sketch

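import jax.numpy as jnp
from flax import nnx

class MyBatchNorm(nnx.Module):
    def __init__(self, d, momentum, eps, rngs):
        # rngs is accepted for API consistency but unused: every init is constant.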
        self.gamma = nnx.Param(jnp.ones((d,)))
        self.beta = nnx.Param(jnp.zeros((d,)))
        self.running_mean = nnx.Variable(jnp.zeros((d,)))
        self.running_var = nnx.Variable(jnp.ones((d,)))
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, use_running_average):
        if use_running_average:
            mu = self.running_mean.value
            var = self.running_var.value
        else:
            mu = jnp.mean(x, axis=0)
            var = jnp.mean((x - mu) ** 2, axis=0)
            # Mutate running stats IN PLACE.
            self.running_mean.value = (
                self.momentum * self.running_mean.value
                + (1.0 - self.momentum) * mu
            )
            self.running_var.value = (
                self.momentum * self.running_var.value
                + (1.0 - self.momentum) * var
            )
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

The mutation lines are the headline. In Linen the equivalent is:

# Linen — for contrast.
if not self.is_initializing():
    running_mean.value = momentum * running_mean.value + (1 - momentum) * mu
# ... and at apply-site:
out, updates = model.apply({"params": p, "batch_stats": bs}, x,
                            use_running_average=False,
                            mutable=["batch_stats"])
new_bs = updates["batch_stats"]   # carry forward

Three points of friction: the is_initializing() guard, the mutable=["batch_stats"] declaration, and the (out, updates) return tuple. nnx makes them all go away.

What “in place” means under JAX

JAX arrays are immutable, so self.running_mean.value = new_array does not mutate the array; it rebinds the wrapper's .value attribute to a new array. The nnx.Variable wrapper provides the appearance of mutation while keeping JAX semantics underneath. When you later call nnx.split, the State it returns holds the new value.

Common pitfalls

  • running_var initialized to zeros. The first eval pass would divide by sqrt(0 + eps) = sqrt(eps) (about 0.0032 for eps=1e-5), blowing up the output. Init to ones.
  • Updating running stats in eval mode. The if use_running_average branch must NOT mutate; only the train branch updates.
  • Using running stats in train mode. Train mode normalizes by batch stats. The running stats are only for eval.
  • use_running_average arriving as float. The harness passes it as 0.0 / 1.0; cast to bool with bool(flag >= 0.5). (int(flag) also works for exact 0.0 / 1.0 inputs, but it truncates, so 0.9 would become 0.)
  • Stats over the wrong axis. axis=-1 is LayerNorm, not BatchNorm. For 2-D (N, D) input, BatchNorm uses axis=0.
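
Two of these pitfalls are easy to check with plain Python. The sketch below uses a hypothetical helper to_flag for the float-to-bool cast and assumes eps=1e-5 for the variance comparison.

```python
import math


def to_flag(value):
    """Hypothetical helper: map a 0.0 / 1.0 harness float to a Python bool."""
    return bool(value >= 0.5)


train_flag = to_flag(0.0)    # training branch
eval_flag = to_flag(1.0)     # eval branch

# int() truncates toward zero, so it only agrees on exact 0.0 / 1.0 inputs.
truncated = int(0.9)
thresholded = to_flag(0.9)

# Scale applied to (x - mean) on the first eval pass, for each choice of init.
eps = 1e-5
scale_zero_init = 1.0 / math.sqrt(0.0 + eps)
scale_ones_init = 1.0 / math.sqrt(1.0 + eps)

print(train_flag, eval_flag, truncated, thresholded)
print(scale_zero_init, scale_ones_init)
```

With running_var initialized to zeros the first eval pass multiplies the centered input by a factor in the hundreds, while the ones init leaves it essentially unscaled.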

Problem

Write batchnorm_forward(seed, x, use_running_average):

  1. Define MyBatchNorm(nnx.Module) with the four state attributes above (gamma, beta as nnx.Param; running_mean, running_var as nnx.Variable), plus momentum=0.9 and eps=1e-5 as plain attributes.
  2. __call__(x, use_running_average):
    • If True: read running stats, normalize, no mutation.
    • Else: compute batch stats, update running stats via EMA, normalize.
    • Return gamma * x_hat + beta.
  3. Cast use_running_average from float to bool: use_run = bool(use_running_average >= 0.5).
  4. Build with nnx.Rngs(int(seed)), instantiate MyBatchNorm(d=x.shape[-1], ...), return model(x, use_running_average=use_run).reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x: 2-D (N, D).
  • use_running_average: float (0.0 or 1.0).

Output: 1-D flattened.

