
Implement BatchNorm with Mutable batch_stats

Why this matters

BatchNorm is one of the most widely used normalization techniques in deep learning. It is also one of the subtler layers to implement in Flax, because it has something most other layers don't: mutable, non-parameter state.

The running mean and running variance are NOT trainable parameters: no gradients are computed for them, and the optimizer never touches them. But they DO change between calls, because you update them as you process training batches. Flax handles this with variable collections, which are namespaces in the variable tree separate from "params".

The convention: BatchNorm’s running statistics live in a "batch_stats" collection.

Variables vs params

self.param(name, init, shape) declares a trainable parameter in the "params" collection. Gradients flow through it; optimizers update it.

self.variable(collection_name, var_name, init_fn) declares a non-parameter variable in any collection you name. Gradients don’t flow through it; you update it manually inside __call__.

# Inside an @nn.compact __call__, where d = x.shape[-1] is the feature size.
running_mean = self.variable(
    "batch_stats", "running_mean",
    lambda: jnp.zeros((d,))
)
# Read:  running_mean.value
# Write: running_mean.value = new_value
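The following minimal NumPy sketch makes the params/batch_stats split concrete. NumPy stands in for JAX here so the example runs anywhere. It builds a variable tree shaped like the one a BatchNorm layer produces and uses a hypothetical `sgd_step` helper as a stand-in for an optimizer. The optimizer only ever sees "params", and "batch_stats" is carried through unchanged.

```python
import numpy as np

# Hypothetical variable tree for a BatchNorm layer with d = 3 features,
# mirroring the {"params": ..., "batch_stats": ...} layout described above.
d = 3
variables = {
    "params": {"gamma": np.ones(d), "beta": np.zeros(d)},
    "batch_stats": {"running_mean": np.zeros(d), "running_var": np.ones(d)},
}

def sgd_step(params, grads, lr=0.1):
    """Plain SGD over the "params" collection only (hypothetical helper)."""
    return {name: params[name] - lr * grads[name] for name in params}

# Pretend these gradients came from differentiating w.r.t. params only.
grads = {"gamma": np.full(d, 0.5), "beta": np.full(d, -1.0)}

new_variables = {
    "params": sgd_step(variables["params"], grads),
    # The optimizer never touches batch_stats; only a training-mode forward
    # pass writes new running statistics.
    "batch_stats": variables["batch_stats"],
}
print(new_variables["params"]["gamma"])  # [0.95 0.95 0.95]
print(new_variables["params"]["beta"])   # [0.1 0.1 0.1]
```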

init / apply with collections

model.init(key, x, ...) returns ALL collections — usually {"params": {...}, "batch_stats": {...}}. Split them:

variables = model.init(key, x, use_running_average=False)
params = variables["params"]
batch_stats = variables["batch_stats"]

For apply, declare which collections are mutable. Anything not in mutable=... is read-only:

out, updated = model.apply(
    {"params": params, "batch_stats": batch_stats},
    x,
    use_running_average=False,
    mutable=["batch_stats"],          # batch_stats can be written
)
new_batch_stats = updated["batch_stats"]

When mutable is non-empty, apply returns a tuple (output, mutated_state), where mutated_state is a dict holding only the collections listed in mutable.
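The return convention can be mimicked in plain NumPy. The `toy_apply` function below is a hypothetical stand-in, not Flax's implementation. It takes the same two-collection dict and returns only the output when nothing is mutable. When "batch_stats" is mutable, it returns (output, {"batch_stats": ...}). It also builds new arrays instead of mutating the caller's dict, which matches the functional behavior apply gives you.

```python
import numpy as np

def toy_apply(variables, x, use_running_average, mutable=(), momentum=0.9, eps=1e-5):
    """Hypothetical NumPy stand-in for model.apply on a BatchNorm layer.

    Returns just `out` when `mutable` is empty and `(out, updated)` otherwise,
    where `updated` holds only the "batch_stats" collection.
    """
    p, s = variables["params"], variables["batch_stats"]
    if use_running_average:
        mu, var = s["running_mean"], s["running_var"]
        new_stats = s  # eval: nothing is written
    else:
        if "batch_stats" not in mutable:
            # Flax raises when the write is attempted; this toy raises up front.
            raise ValueError("batch_stats is read-only; pass mutable=['batch_stats']")
        mu, var = x.mean(axis=0), x.var(axis=0)
        # Build NEW arrays rather than mutating `s` in place.
        new_stats = {
            "running_mean": momentum * s["running_mean"] + (1 - momentum) * mu,
            "running_var": momentum * s["running_var"] + (1 - momentum) * var,
        }
    out = p["gamma"] * (x - mu) / np.sqrt(var + eps) + p["beta"]
    if mutable:
        return out, {"batch_stats": new_stats}
    return out


variables = {
    "params": {"gamma": np.ones(2), "beta": np.zeros(2)},
    "batch_stats": {"running_mean": np.zeros(2), "running_var": np.ones(2)},
}
x = np.array([[1.0, 2.0], [3.0, 6.0]])

out, updated = toy_apply(variables, x, use_running_average=False, mutable=["batch_stats"])
print(updated["batch_stats"]["running_mean"])    # [0.2 0.4]
print(variables["batch_stats"]["running_mean"])  # [0. 0.]  (caller's copy untouched)
```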

The is_initializing() guard

During init, the forward pass runs, but you don't want to update the running stats with whatever example input you passed to init. Flax provides self.is_initializing(), which returns True only while running under init():

if not self.is_initializing():
    running_mean.value = momentum * running_mean.value + (1 - momentum) * mu
    running_var.value  = momentum * running_var.value  + (1 - momentum) * var

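Without this guard, init's forward pass would also apply the EMA update. The write succeeds, because init treats every collection as mutable. As a result, the running stats would start out nudged toward the init input's statistics instead of at their intended zeros and ones. init's job is only to allocate them.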
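The effect can be modelled in NumPy. The hypothetical `init_time_stats` function below reproduces init's training-mode forward pass using MyBatchNorm's default momentum of 0.9. It shows what the stored stats look like with and without the guard, for a constant dummy input.

```python
import numpy as np

MOMENTUM = 0.9

def init_time_stats(x, use_guard):
    """Hypothetical NumPy model of the stats after init's training-mode pass."""
    d = x.shape[-1]
    running_mean, running_var = np.zeros(d), np.ones(d)  # freshly allocated
    mu, var = x.mean(axis=0), x.var(axis=0)
    if not use_guard:
        # The EMA write that `if not self.is_initializing():` skips.
        running_mean = MOMENTUM * running_mean + (1 - MOMENTUM) * mu
        running_var = MOMENTUM * running_var + (1 - MOMENTUM) * var
    return running_mean, running_var

dummy_x = np.full((4, 2), 5.0)  # whatever example batch init happens to see
print(init_time_stats(dummy_x, use_guard=True))   # zeros / ones, as allocated
print(init_time_stats(dummy_x, use_guard=False))  # mean ~0.5, var ~0.9
```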

Train vs eval branches

BatchNorm has TWO modes controlled by use_running_average:

  • use_running_average=False (training): use batch stats; update running stats.
  • use_running_average=True (eval): use running stats (no batch dependency).
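The following minimal NumPy sketch shows what "no batch dependency" means in practice. It assumes gamma = 1, beta = 0 and made-up running stats. In eval mode a row normalizes identically whether it is alone or inside a batch. In training mode its output depends on its batch-mates.

```python
import numpy as np

eps = 1e-5
running_mean = np.array([1.0, -2.0])   # hypothetical stats accumulated in training
running_var = np.array([4.0, 0.25])

def bn(x, use_running_average):
    if use_running_average:
        mu, var = running_mean, running_var      # fixed: independent of the batch
    else:
        mu, var = x.mean(axis=0), x.var(axis=0)  # depends on every row in x
    return (x - mu) / np.sqrt(var + eps)         # gamma=1, beta=0 for brevity

batch = np.array([[0.0, 1.0], [2.0, -3.0], [10.0, 0.5]])
single = batch[:1]

# Eval: row 0 normalizes the same whether it is alone or in a batch.
print(np.allclose(bn(single, True), bn(batch, True)[:1]))    # True
# Train: row 0's output changes with its batch-mates (alone, it becomes 0).
print(np.allclose(bn(single, False), bn(batch, False)[:1]))  # False
```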

Worked sketch

import flax.linen as nn
import jax.numpy as jnp

class MyBatchNorm(nn.Module):
    momentum: float = 0.9
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x, use_running_average: bool):
        d = x.shape[-1]
        # Trainable scale and shift live in the "params" collection.
        gamma = self.param("gamma", nn.initializers.ones, (d,))
        beta = self.param("beta", nn.initializers.zeros, (d,))
        # Running statistics live in "batch_stats"; they are not trained.
        running_mean = self.variable(
            "batch_stats", "running_mean", lambda: jnp.zeros((d,))
        )
        running_var = self.variable(
            "batch_stats", "running_var", lambda: jnp.ones((d,))
        )
        if use_running_average:
            # Eval: use the EMA statistics accumulated during training.
            mu, var = running_mean.value, running_var.value
        else:
            # Train: per-feature statistics over the batch axis.
            mu = jnp.mean(x, axis=0)
            var = jnp.mean((x - mu) ** 2, axis=0)
            # Skip the EMA update during init so the stats keep their initial values.
            if not self.is_initializing():
                running_mean.value = self.momentum * running_mean.value + (1 - self.momentum) * mu
                running_var.value  = self.momentum * running_var.value  + (1 - self.momentum) * var
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return gamma * x_hat + beta
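The training branch can be sanity-checked in NumPy, so the check runs without Flax. After normalization, each feature's batch mean should equal beta. Its standard deviation should be almost exactly gamma, with eps shrinking it very slightly. The gamma and beta values below are arbitrary choices for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(256, 4))  # (N, D) batch
gamma = np.array([1.0, 2.0, 0.5, 1.5])   # arbitrary scale for the check
beta = np.array([0.0, -1.0, 3.0, 0.25])  # arbitrary shift for the check
eps = 1e-5

# Same arithmetic as the training branch of MyBatchNorm.
mu = x.mean(axis=0)
var = ((x - mu) ** 2).mean(axis=0)
y = gamma * (x - mu) / np.sqrt(var + eps) + beta

print(np.allclose(y.mean(axis=0), beta, atol=1e-6))  # True
print(np.allclose(y.std(axis=0), gamma, atol=1e-4))  # True
```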

Statistics are over axis=0 (the batch axis). For 2-D (N, D) input this gives shape (D,) per-feature stats.
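The axis choice is easiest to see through the shapes. In this NumPy sketch, reducing over axis=0 gives one statistic per feature, as BatchNorm does. Reducing over the last axis gives one statistic per example, which is what LayerNorm does.

```python
import numpy as np

x = np.arange(12, dtype=float).reshape(4, 3)  # N=4 examples, D=3 features

bn_mu = x.mean(axis=0)                  # BatchNorm: one mean per feature
ln_mu = x.mean(axis=-1, keepdims=True)  # LayerNorm: one mean per example

print(bn_mu.shape, bn_mu)  # (3,) [4.5 5.5 6.5]
print(ln_mu.shape)         # (4, 1)
```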

Common pitfalls

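  • Forgetting is_initializing(): init does not raise, because it treats every collection as mutable. Instead, the stored running stats are silently shifted toward the init input's statistics instead of starting at zeros and ones.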
  • Forgetting mutable=: apply raises because it tries to write to a collection that wasn’t declared mutable.
  • Wrong axis for batch stats: BatchNorm normalizes across batch (axis=0), not across features (which would be LayerNorm).
  • Using batch stats during eval: pass use_running_average=True so the output uses the EMA running stats. Nothing is written in that mode, so no update is needed.
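The EMA stats that eval relies on only become representative after enough training steps. The minimal pure-Python sketch below assumes an idealized stream in which every batch mean equals a hypothetical target of 3.0. With momentum m, the initial value still carries weight m^k after k updates, so the running mean equals 3.0 * (1 - 0.9^k).

```python
MOMENTUM = 0.9
TARGET = 3.0          # hypothetical true per-feature mean of the data
running_mean = 0.0    # value allocated at init

history = {}
for step in range(1, 51):
    batch_mean = TARGET  # idealized: every batch mean equals the data mean
    running_mean = MOMENTUM * running_mean + (1 - MOMENTUM) * batch_mean
    history[step] = running_mean

for k in (1, 10, 50):
    # Closed form: TARGET * (1 - MOMENTUM**k)
    print(k, round(history[k], 4))
```

With momentum 0.9, the running mean is about 65% of the way to the target after 10 updates and about 99.5% after 50. A larger momentum such as 0.99 converges more slowly.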

Problem

Implement MyBatchNorm per the sketch above.

The function does the full init + apply in one go:

  1. Init with an RNG dict {"params": key}, where key is a PRNG key derived from seed, passing use_running_average=False.
  2. Split out params and batch_stats from the returned variables dict.
  3. Apply with the test’s use_running_average flag and mutable=["batch_stats"].
  4. Return the output (flattened to 1-D for tests).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • use_running_average: float (0.0 or 1.0).

Output: 1-D flattened.
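For checking results by hand, here is a hypothetical NumPy reference for the flattened output. It assumes MyBatchNorm is implemented as in the worked sketch, with the is_initializing() guard. Every initializer in that sketch is deterministic (gamma ones, beta zeros, running stats zeros and ones), so the seed should not affect the output. With use_running_average=1.0 the result is x / sqrt(1 + eps), and with 0.0 it is the batch-normalized x.

```python
import numpy as np

def reference_output(x, use_running_average, eps=1e-5):
    """Hypothetical NumPy reference for the problem's expected output.

    Assumes the freshly initialized state: gamma=1, beta=0, and (thanks to the
    is_initializing() guard) running_mean=0, running_var=1.
    """
    x = np.asarray(x, dtype=float)
    d = x.shape[-1]
    if bool(use_running_average):  # flag arrives as 0.0 or 1.0
        mu, var = np.zeros(d), np.ones(d)
    else:
        mu, var = x.mean(axis=0), x.var(axis=0)
    return ((x - mu) / np.sqrt(var + eps)).reshape(-1)  # flatten to 1-D
```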

Hints

flax batchnorm mutable-state
