NNX: Implement LayerNorm
Why this matters
LayerNorm is the normalization of choice in transformers. Unlike
BatchNorm, it has NO running statistics: every forward pass computes
its own mean and variance from the input alone. That makes it stateless
in a way BatchNorm isn't, and consequently trivial to implement in
nnx: two nnx.Params and four lines of __call__.
Compare with Linen, where you'd typically use @nn.compact (or setup) and
self.param(...). The math doesn't change; the surrounding code does.
API: gain and offset
LayerNorm has two trainable parameters per feature:
- gamma (a.k.a. scale, weight): per-feature gain. Init to ones.
- beta (a.k.a. offset, bias): per-feature offset. Init to zeros.
Both have shape (D,) where D is the size of the last axis.
The forward computes mean and variance over the last axis, normalizes, then applies the affine transform:
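```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):  # illustrative wrapper around the forward
    mu = jnp.mean(x, axis=-1, keepdims=True)                # shape (..., 1)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)   # shape (..., 1)
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta
```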
keepdims=True keeps mu and var with shape (..., 1) so they
broadcast against x cleanly.
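A quick numerical check of the shapes and statistics, assuming a small hand-picked (T, D) = (2, 4) input (the values are arbitrary). With keepdims=True the per-row statistics keep a trailing axis of length 1 and broadcast along each row, and the normalized rows come out with mean about 0 and variance about 1. The variance here is the biased, divide-by-D estimator, which matches jnp.var with its default ddof=0.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 10.0, 20.0, 40.0]])   # (T=2 tokens, D=4 features), arbitrary values
eps = 1e-5

mu = jnp.mean(x, axis=-1, keepdims=True)                # (2, 1): [[2.5], [20.0]]
var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)   # (2, 1): [[1.25], [150.0]]
x_hat = (x - mu) / jnp.sqrt(var + eps)                  # (2, 4) op (2, 1): broadcast along each row

mu_no_keepdims = jnp.mean(x, axis=-1)                   # (2,): the trailing axis is gone

row_means = jnp.mean(x_hat, axis=-1)   # each close to 0
row_vars = jnp.var(x_hat, axis=-1)     # each close to 1; eps shrinks it by var / (var + eps)
```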
Why over the LAST axis?
LayerNorm normalizes per-token (or per-sample) over the feature axis.
For a transformer with input shape (T, D), you want each of the T
tokens to have unit-variance, zero-mean features — the statistics are
computed within each row independently. The last axis (D) is where
the features live by Flax convention, so axis=-1 is the right choice.
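Contrast with BatchNorm (statistics over the batch axis, per feature) and InstanceNorm (statistics over the spatial axes only, per sample and per channel). The choice of axes is what distinguishes them; the normalize-then-scale-and-shift math is the same, although BatchNorm additionally tracks running statistics for inference.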
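To see that the axis really is the only difference, the sketch below applies one hypothetical helper, normalize (no gain, offset, or running statistics), with three axis choices. The BatchNorm case uses training-mode batch statistics only, and the (B, H, W, C) image layout for InstanceNorm is an assumption following the channels-last convention.

```python
import jax.numpy as jnp

def normalize(x, axes, eps=1e-5):
    # Hypothetical helper: zero mean, unit variance over `axes`; no gamma/beta, no running stats.
    mu = jnp.mean(x, axis=axes, keepdims=True)
    var = jnp.var(x, axis=axes, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 8.0, 12.0]])      # (B=2 samples, D=3 features)

ln = normalize(x, axes=-1)   # LayerNorm: statistics per row (per sample, over features)
bn = normalize(x, axes=0)    # BatchNorm-style batch stats: per column (per feature, over batch)

# InstanceNorm on a channels-last image batch (B, H, W, C): per sample and channel, over H and W.
imgs = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)
inorm = normalize(imgs, axes=(1, 2))
```

Here every row of ln has mean 0, while for bn it is every column; with this input bn comes out as roughly [[-1, -1, -1], [1, 1, 1]], quite different from ln.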
Worked example
```python
import jax.numpy as jnp
from flax import nnx

class MyLayerNorm(nnx.Module):
    def __init__(self, d, eps, rngs):
        # rngs unused: gamma and beta have deterministic init.
        self.gamma = nnx.Param(jnp.ones((d,)))
        self.beta = nnx.Param(jnp.zeros((d,)))
        self.eps = eps  # plain attribute, not a trainable Param

    def __call__(self, x):
        mu = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta
```
Notice rngs is unused. Both gamma and beta have deterministic
init (ones, zeros) so no random key is needed. The convention is to
accept rngs anyway, in case you later swap to a randomized init; it also
keeps the construction signature consistent across layers.
Linen contrast
```python
# Linen, for contrast.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyLayerNorm(nn.Module):
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]
        gamma = self.param("gamma", nn.initializers.ones, (d,))
        beta = self.param("beta", nn.initializers.zeros, (d,))
        mu = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return gamma * x_hat + beta

# Init / apply ceremony:
x = jnp.arange(16.0).reshape(2, 8)  # example input; any (..., D) array works
model = MyLayerNorm()
variables = model.init(jax.random.PRNGKey(0), x)  # {"params": {"gamma": ..., "beta": ...}}
y = model.apply(variables, x)
```
Same arithmetic, with init+apply overhead. nnx’s version skips both.
Common pitfalls
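- Wrong axis. axis=0 (the batch axis) gives BatchNorm semantics, not LayerNorm; axis=-1 (the feature axis) is right.
- Forgetting keepdims=True. Without it, mu has shape (...) instead of (..., 1). For batched input, x - mu then either raises a shape error or, when the shapes happen to line up (e.g. a square (D, D) input), silently subtracts the wrong means. For 1-D input mu is a scalar, so the bug stays hidden until a batch axis appears.
- Missing eps. When the input is constant, var is 0, so (x - mu) / sqrt(var) is 0/0 = NaN. sqrt(var + eps) keeps the denominator positive.
- Initializing gamma to zeros. The output then equals beta (all zeros at init) regardless of the input; ones is the standard.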
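The two quieter pitfalls can be reproduced in a few lines of jax.numpy; the inputs are arbitrary. A square (3, 3) input hides the keepdims bug because shapes (3, 3) and (3,) broadcast without complaint, and a constant row shows the NaN that eps prevents.

```python
import jax.numpy as jnp

# keepdims pitfall: with a square input, dropping keepdims raises no error, just gives wrong numbers.
x = jnp.array([[1.0, 2.0, 3.0],
               [10.0, 20.0, 30.0],
               [0.0, 0.0, 3.0]])
mu_right = jnp.mean(x, axis=-1, keepdims=True)   # (3, 1): row means [2, 20, 1] as a column
mu_wrong = jnp.mean(x, axis=-1)                  # (3,): same numbers, but laid out as a row
centered_right = x - mu_right                    # subtracts row i's mean from row i
centered_wrong = x - mu_wrong                    # subtracts row i's mean from column i instead

# eps pitfall: a constant input has zero variance.
const = jnp.full((4,), 5.0)
mu = jnp.mean(const)
var = jnp.mean((const - mu) ** 2)                # exactly 0.0
no_eps = (const - mu) / jnp.sqrt(var)            # 0 / 0 -> nan
with_eps = (const - mu) / jnp.sqrt(var + 1e-5)   # 0 / sqrt(eps) -> 0
```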
Problem
Write layernorm_forward(seed, x, eps):
- Define MyLayerNorm(nnx.Module) with self.gamma = nnx.Param(jnp.ones((d,))), self.beta = nnx.Param(jnp.zeros((d,))), and self.eps = eps (plain attribute).
- __call__ computes mean and variance over axis=-1 with keepdims=True, then gamma * (x - mu) / sqrt(var + eps) + beta.
- Build with nnx.Rngs(int(seed)) (rngs is unused but accepted), instantiate (d=x.shape[-1], eps=float(eps)), return model(x).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- eps: float.
Output: same shape as x.