Implement BatchNorm with Mutable batch_stats
Why this matters
BatchNorm is one of the most widely used normalization techniques in deep learning. It is also one of the subtler layers to implement in Flax, because it has something most other layers don’t: mutable, non-parameter state.
The running mean and running variance are NOT trainable parameters
(gradients don’t flow through them) but they DO change between calls
(you update them as you process training batches). Flax handles this
via variable collections — namespaces in the variable tree
distinct from "params".
The convention: BatchNorm’s running statistics live in a "batch_stats"
collection.
Variables vs params
self.param(name, init, shape) declares a trainable parameter in
the "params" collection. Gradients flow through it; optimizers update it.
self.variable(collection_name, var_name, init_fn) declares a
non-parameter variable in any collection you name. Gradients don’t
flow through it; you update it manually inside __call__.
```python
running_mean = self.variable(
    "batch_stats", "running_mean",
    lambda: jnp.zeros((d,))
)
# Read: running_mean.value
# Write: running_mean.value = new_value
```
init / apply with collections
model.init(key, x, ...) returns ALL collections — usually
{"params": {...}, "batch_stats": {...}}. Split them:
```python
variables = model.init(key, x, use_running_average=False)
params = variables["params"]
batch_stats = variables["batch_stats"]
```
For apply, declare which collections are mutable. Anything not in
mutable=... is read-only:
```python
out, updated = model.apply(
    {"params": params, "batch_stats": batch_stats},
    x,
    use_running_average=False,
    mutable=["batch_stats"],  # batch_stats can be written
)
new_batch_stats = updated["batch_stats"]
```
apply returns (output, mutated_state) when mutable is non-empty.
The is_initializing() guard
During init, the forward pass runs, but you don’t want to update the running
stats with the statistics of the example input passed to init. Flax provides
self.is_initializing(), which returns True only while the module runs inside init():
```python
if not self.is_initializing():
    running_mean.value = self.momentum * running_mean.value + (1 - self.momentum) * mu
    running_var.value = self.momentum * running_var.value + (1 - self.momentum) * var
```
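Without this guard, init would fold the example input’s statistics into the
freshly allocated running stats, so they would no longer start at their initial
values (zeros for the mean, ones for the variance). The purpose of init is only to allocate them.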
Train vs eval branches
BatchNorm has TWO modes controlled by use_running_average:
- use_running_average=False (training): normalize with the current batch’s statistics and update the running stats.
- use_running_average=True (eval): normalize with the running stats (no dependence on the rest of the batch).
Worked sketch
```python
import jax.numpy as jnp
import flax.linen as nn

class MyBatchNorm(nn.Module):
    momentum: float = 0.9
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x, use_running_average: bool):
        d = x.shape[-1]
        # Trainable affine parameters ("params" collection).
        gamma = self.param("gamma", nn.initializers.ones, (d,))
        beta = self.param("beta", nn.initializers.zeros, (d,))
        # Non-trainable running statistics ("batch_stats" collection).
        running_mean = self.variable(
            "batch_stats", "running_mean", lambda: jnp.zeros((d,))
        )
        running_var = self.variable(
            "batch_stats", "running_var", lambda: jnp.ones((d,))
        )
        if use_running_average:
            # Eval: normalize with the stored EMA statistics.
            mu, var = running_mean.value, running_var.value
        else:
            # Train: normalize with the current batch's statistics...
            mu = jnp.mean(x, axis=0)
            var = jnp.mean((x - mu) ** 2, axis=0)
            # ...and fold them into the running EMA (skipped during init).
            if not self.is_initializing():
                running_mean.value = self.momentum * running_mean.value + (1 - self.momentum) * mu
                running_var.value = self.momentum * running_var.value + (1 - self.momentum) * var
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return gamma * x_hat + beta
```
Statistics are over axis=0 (the batch axis). For 2-D (N, D) input
this gives shape (D,) per-feature stats.
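The axis choice is easy to check numerically. This plain jax.numpy sketch uses arbitrary input values. It computes per-feature statistics over axis=0 and confirms that each normalized column has roughly zero mean and unit variance. It also shows how LayerNorm-style statistics, taken over the feature axis, come out with a different shape.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 10.0],
               [2.0, 20.0],
               [3.0, 30.0],
               [6.0, 60.0]])  # (N=4, D=2)

mu = jnp.mean(x, axis=0)                 # per-feature mean, shape (D,)
var = jnp.mean((x - mu) ** 2, axis=0)    # per-feature biased variance, shape (D,)
x_hat = (x - mu) / jnp.sqrt(var + 1e-5)  # (N, D) broadcast against (D,)

# LayerNorm-style statistics would reduce over the feature axis instead:
ln_mu = jnp.mean(x, axis=-1, keepdims=True)  # one mean per example, shape (N, 1)
```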
Common pitfalls
- Forgetting is_initializing(): init does not fail, because Flax’s init treats batch_stats as mutable. Instead, it silently folds the example input’s statistics into the running stats, so they no longer start at zeros/ones.
- Forgetting mutable= in apply: apply raises because the training branch writes to a collection that wasn’t declared mutable.
- Wrong axis for batch stats: BatchNorm normalizes across the batch (axis=0), not across features (which would be LayerNorm).
- Updating running stats during eval: pass use_running_average=True. The eval-time output then uses the EMA stats, and no update is needed.
Problem
Implement MyBatchNorm per the sketch above.
The function does the full init + apply in one go:
- Init with an rngs dict {"params": ...} derived from seed, and use_running_average=False.
- Split out params and batch_stats from the returned variables dict.
- Apply with the test’s use_running_average flag and mutable=["batch_stats"].
- Return the output (flattened to 1-D for tests).
Inputs:
- seed: float (cast to int).
- x: 2-D array of shape (N, D).
- use_running_average: float (0.0 or 1.0).
Output: 1-D flattened.