
NNX Implement LayerNorm

Why this matters

LayerNorm is the normalization of choice in transformers. Unlike BatchNorm it has NO running statistics — every forward pass computes its own mean and variance from the input alone. That makes it stateless in a way BatchNorm isn’t, and consequently trivial to implement in nnx: two nnx.Params and four lines of __call__.

Compare with Linen, where you’d declare parameters with self.param(...) inside an @nn.compact method (or in setup()) and then go through init and apply to use them. The math doesn’t change; the surrounding code does.
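To make the "no running statistics" point concrete, here is a minimal JAX sketch. It uses a plain function, `layer_norm` (a hypothetical stand-in for the module, not a library API), and checks two things: row 0's output does not change when the other rows of the batch change, and calling twice gives the same answer because nothing is updated between calls.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics come from x alone, one row at a time (last axis).
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta


d = 8
gamma, beta = jnp.ones((d,)), jnp.zeros((d,))
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
batch = jax.random.normal(key_a, (4, d))
# Same first row, very different remaining rows.
other = batch.at[1:].set(jax.random.normal(key_b, (3, d)) * 10.0)

y1 = layer_norm(batch, gamma, beta)
y2 = layer_norm(other, gamma, beta)
# Row 0 is normalized identically regardless of its batch neighbors.
row0_same = bool(jnp.allclose(y1[0], y2[0]))

# Repeated calls agree: there is no running mean/variance to update.
repeat_same = bool(jnp.array_equal(layer_norm(batch, gamma, beta), y1))
print(row0_same, repeat_same)  # True True
```

A BatchNorm-style normalization over axis=0 would fail the first check, since each row's output depends on the whole batch.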

API: gain and offset

LayerNorm has two trainable parameters per feature:

  • gamma (a.k.a. scale, weight): per-feature gain. Init to ones.
  • beta (a.k.a. offset, bias): per-feature offset. Init to zeros.

Both have shape (D,) where D is the size of the last axis.

The forward computes mean and variance over the last axis, normalizes, then applies the affine transform:

mu = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
x_hat = (x - mu) / jnp.sqrt(var + eps)
return gamma * x_hat + beta

keepdims=True keeps mu and var with shape (..., 1) so they broadcast against x cleanly.
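The four lines above can be run directly as a small sketch. The input shape (2, 3, 5) and eps = 1e-5 are arbitrary choices for illustration; the point is that mu and var come out as (..., 1) and broadcast against x, while gamma and beta of shape (D,) broadcast along the feature axis.

```python
import jax
import jax.numpy as jnp

eps = 1e-5
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 5))  # (..., D) with D = 5
d = x.shape[-1]
gamma = jnp.ones((d,))   # per-feature gain, init to ones
beta = jnp.zeros((d,))   # per-feature offset, init to zeros

mu = jnp.mean(x, axis=-1, keepdims=True)                  # (2, 3, 1)
var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)     # (2, 3, 1)
x_hat = (x - mu) / jnp.sqrt(var + eps)                    # (2, 3, 5)
y = gamma * x_hat + beta                                  # (2, 3, 5)

print(mu.shape, var.shape, y.shape)  # (2, 3, 1) (2, 3, 1) (2, 3, 5)
```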

Why over the LAST axis?

LayerNorm normalizes per token (or per sample) over the feature axis. For a transformer input of shape (T, D), you want each of the T tokens to have zero-mean, unit-variance features before gamma and beta are applied, so the statistics are computed within each row independently. By Flax convention the features live on the last axis (D), so axis=-1 is the right choice.
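A quick check of the per-row behavior on a (T, D) input, as a sketch: the rows are deliberately given different offsets and scales (an arbitrary construction for illustration), and after normalizing over axis=-1 every row ends up with mean about 0 and variance about 1.

```python
import jax
import jax.numpy as jnp

T, D = 4, 16
eps = 1e-5
# Give each token a different offset and scale so the raw rows differ a lot.
x = jax.random.normal(jax.random.PRNGKey(1), (T, D))
x = x * jnp.arange(1.0, T + 1.0)[:, None] + 10.0 * jnp.arange(T)[:, None]

mu = jnp.mean(x, axis=-1, keepdims=True)               # one mean per token, (T, 1)
var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)  # one variance per token, (T, 1)
x_hat = (x - mu) / jnp.sqrt(var + eps)

# Each row is now (approximately) zero-mean and unit-variance over its D features.
print(jnp.round(jnp.mean(x_hat, axis=-1), 4))
print(jnp.round(jnp.var(x_hat, axis=-1), 4))
```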

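Contrast with BatchNorm (statistics over the batch axis, plus the spatial axes for conv feature maps) and InstanceNorm (statistics over the spatial axes only, separately for each sample and channel). The choice of axes is what distinguishes them; the normalize-then-scale-and-shift arithmetic is the same, although BatchNorm additionally keeps running statistics for inference.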
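The "only the axes differ" point can be sketched with one shared helper, `normalize` (a hypothetical name), applied to a channels-last (N, H, W, C) batch. This covers training-mode statistics only: BatchNorm's running averages and all affine parameters are omitted, and "LayerNorm" here means the last-axis version used throughout this section.

```python
import jax
import jax.numpy as jnp


def normalize(x, axes, eps=1e-5):
    # Shared normalize step; only `axes` differs between the variants.
    mu = jnp.mean(x, axis=axes, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=axes, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)


# Channels-last image batch: (N, H, W, C).
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 3))

layer = normalize(x, axes=(-1,))        # LayerNorm: features of each position
batch = normalize(x, axes=(0, 1, 2))    # conv-style BatchNorm: batch + spatial, per channel
instance = normalize(x, axes=(1, 2))    # InstanceNorm: spatial only, per sample and channel

print(jnp.mean(layer, axis=-1).shape)         # (2, 4, 4): one statistic per position
print(jnp.mean(batch, axis=(0, 1, 2)).shape)  # (3,): one per channel
print(jnp.mean(instance, axis=(1, 2)).shape)  # (2, 3): one per sample and channel
```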

Worked example

import jax.numpy as jnp
from flax import nnx

class MyLayerNorm(nnx.Module):
    def __init__(self, d: int, eps: float, rngs: nnx.Rngs):
        # rngs is unused: gamma and beta have deterministic init.
        self.gamma = nnx.Param(jnp.ones((d,)))   # per-feature gain
        self.beta = nnx.Param(jnp.zeros((d,)))   # per-feature offset
        self.eps = eps                           # plain float, not a Param

    def __call__(self, x):
        # Per-row statistics over the feature (last) axis, shape (..., 1).
        mu = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

Notice that rngs is unused: gamma and beta have deterministic init (ones and zeros), so no random key is needed. The convention is to accept rngs anyway. That keeps the constructor signature consistent across layers and leaves room to switch to a randomized init later without changing any call sites.

Linen contrast

# Linen, for contrast.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyLayerNorm(nn.Module):
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]
        gamma = self.param("gamma", nn.initializers.ones, (d,))
        beta = self.param("beta", nn.initializers.zeros, (d,))
        mu = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
        x_hat = (x - mu) / jnp.sqrt(var + self.eps)
        return gamma * x_hat + beta

# Init / apply ceremony:
x = jnp.arange(16.0).reshape(2, 8)              # example (..., D) input
model = MyLayerNorm()
params = model.init(jax.random.PRNGKey(0), x)   # {"params": {"gamma": ..., "beta": ...}}
y = model.apply(params, x)                      # params must be passed on every call

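Same arithmetic, plus the init/apply overhead: the parameters live in an external dict that init creates and apply must receive on every call. The nnx version stores gamma and beta on the module object, so you construct it once and call it directly.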

Common pitfalls

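  • Wrong axis. axis=0 on a (B, D) input normalizes each feature across the batch, which is BatchNorm-style statistics, not LayerNorm. axis=-1 (the feature axis) is right.
  • Forgetting keepdims=True. Without it, mu has shape (...) instead of (..., 1), so x - mu lines mu up against the trailing axes of x instead of broadcasting per row. That raises a shape error when the sizes differ, and silently gives wrong results when they match (e.g. a square (D, D) input). For 1-D input the mean is a scalar and broadcasts correctly, so the bug can hide until you add a batch axis.
  • Missing eps. (x - mu) / sqrt(var) divides zero by zero when a row is constant (var == 0), producing NaN. sqrt(var + eps) avoids this.
  • Initializing gamma to zeros. The output then equals beta (all zeros) at init, whatever the input; ones is the standard.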
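Two of these pitfalls are easy to reproduce in a few lines. This sketch uses small arange-based inputs chosen for illustration, and it catches both TypeError and ValueError because the exact exception type for a broadcasting failure has varied across JAX versions.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)  # (T, D) = (3, 4)

# keepdims=False: mu has shape (3,) and is matched against the LAST axis
# of x (size 4), not broadcast per row. With T != D this raises.
mu_bad = jnp.mean(x, axis=-1)
try:
    _ = x - mu_bad
    keepdims_error = False
except (TypeError, ValueError):
    keepdims_error = True

# With a square (D, D) input the same bug runs silently and is wrong.
sq = jnp.arange(16.0).reshape(4, 4)
wrong = sq - jnp.mean(sq, axis=-1)                 # subtracts mean j from column j
right = sq - jnp.mean(sq, axis=-1, keepdims=True)  # subtracts mean i from row i
silently_wrong = not bool(jnp.allclose(wrong, right))

# No eps: a constant row has var == 0, so (x - mu) / sqrt(var) is 0 / 0.
const = jnp.full((4,), 3.0)
mu = jnp.mean(const, keepdims=True)
var = jnp.mean((const - mu) ** 2, keepdims=True)
no_eps = (const - mu) / jnp.sqrt(var)           # NaN everywhere
with_eps = (const - mu) / jnp.sqrt(var + 1e-5)  # zeros everywhere

print(keepdims_error, silently_wrong)  # True True
```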

Problem

Write layernorm_forward(seed, x, eps):

  1. Define MyLayerNorm(nnx.Module) with self.gamma = nnx.Param(jnp.ones((d,))), self.beta = nnx.Param(jnp.zeros((d,))), and self.eps = eps (plain attribute).
  2. __call__ computes mean and variance over axis=-1 with keepdims=True, then gamma * (x - mu) / sqrt(var + eps) + beta.
  3. Create rngs = nnx.Rngs(int(seed)) (unused, but accepted by the constructor), instantiate MyLayerNorm(d=x.shape[-1], eps=float(eps), rngs=rngs), and return model(x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • eps: float.

Output: same shape as x.
