
Implement RMSNorm (Modern LLM Norm)

Why this matters

RMSNorm is what LayerNorm became in modern LLMs (LLaMA, T5, PaLM, Gemma). It’s strictly simpler than LayerNorm (no mean subtraction, no bias), and empirically it matches LayerNorm’s quality while being cheaper to compute on accelerators.

The bet: LayerNorm’s mean subtraction does re-centering work the network can largely do without, and its β bias plays the same role as any other bias in the network. So drop both, and keep only the rescaling by the root mean square plus the learned γ scale.

Math

Compared to LayerNorm:

LayerNorm:  x̂ = (x - μ) / sqrt(var + ε);   y = γ * x̂ + β
RMSNorm:    x̂ = x / sqrt(mean(x²) + ε);    y = γ * x̂

Two differences:

  1. No mean subtraction. Use mean(x²) directly instead of variance.
  2. No β bias. Just γ scale.
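A minimal side-by-side sketch of the two formulas in JAX, assuming normalization over the last axis and γ, β passed in as plain arrays. `layer_norm` and `rms_norm` are illustrative names, not a library API.

```python
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-6):
    # Two statistics over the feature axis: the mean, then the variance.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta  # scale and shift


def rms_norm(x, gamma, eps=1e-6):
    # One statistic over the feature axis: the mean square.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    x_hat = x / jnp.sqrt(ms + eps)  # no mean subtraction
    return gamma * x_hat            # no beta


x = jnp.array([1., 2., 3., 4.])
ones = jnp.ones(4)
zeros = jnp.zeros(4)
print(layer_norm(x, ones, zeros))  # centered: entries sum to ~0
print(rms_norm(x, ones))           # not centered: all entries stay positive
```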

mean(x²) is the mean square — the squared RMS. So the divisor is sqrt(mean(x²) + ε) ≈ RMS of x (with epsilon for stability).
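Written out element by element for a feature vector x with D entries (D is the `x.shape[-1]` used in the Problem section), the RMS and the RMSNorm output are:

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{D}\sum_{i=1}^{D} x_i^2},
\qquad
y_i = \gamma_i \, \frac{x_i}{\sqrt{\frac{1}{D}\sum_{j=1}^{D} x_j^2 + \varepsilon}}
```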

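Since γ is initialized to 1, RMSNorm at init simply divides each row by its RMS, so every output row has RMS ≈ 1 regardless of the input’s scale (the approximation only breaks down when mean(x²) is not much larger than ε).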
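A quick numerical check of that claim, sketched under the assumption γ = 1 and a small batch of rows normalized along the last axis (`rms_norm` here is a hypothetical helper, not a library function): rescaling the input leaves the output essentially unchanged, and each output row has RMS close to 1.

```python
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # gamma = 1 at init, so the output is just x / RMS(x), row by row.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return x / jnp.sqrt(ms + eps)


x = jnp.array([[1., 2., 3., 4.],
               [-0.5, 0.1, 0.0, 2.0]])
y = rms_norm(x)
y_scaled = rms_norm(100.0 * x)  # same rows, 100x larger

row_rms = jnp.sqrt(jnp.mean(y ** 2, axis=-1))
print(row_rms)                         # each entry ≈ 1
print(jnp.max(jnp.abs(y - y_scaled)))  # ≈ 0: the input scale is divided out
```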

Why it’s faster

LayerNorm needs two reductions across the feature axis: the mean (for centering) and then the variance. On top of that it does an elementwise subtraction of the mean and an add of β. RMSNorm needs a single reduction (the mean square) and neither elementwise extra. On accelerators, where reductions are relatively expensive, this matters, and in production LLM training the norm sits in every block on the critical path of the forward pass.
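One way to see exactly what the extra reduction buys, sketched with γ = 1 and β = 0 (`rms_normalize` and `layer_normalize` are illustrative names): LayerNorm’s normalization step is RMSNorm applied to the mean-centered input. The only thing LayerNorm adds is the centering, and that is what costs the additional pass over the feature axis.

```python
import jax
import jax.numpy as jnp


def rms_normalize(x, eps=1e-6):
    # Reduction 1 (of 1): mean square over the feature axis.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return x / jnp.sqrt(ms + eps)


def layer_normalize(x, eps=1e-6):
    # Reduction 1 (of 2): the mean, needed to center x.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    centered = x - mu
    # Reduction 2 (of 2): the variance is the mean square of the centered x.
    var = jnp.mean(centered ** 2, axis=-1, keepdims=True)
    return centered / jnp.sqrt(var + eps)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 8)) + 5.0  # rows with nonzero mean
centered = x - jnp.mean(x, axis=-1, keepdims=True)

# LayerNorm (no affine) matches RMSNorm of the centered input.
print(jnp.max(jnp.abs(layer_normalize(x) - rms_normalize(centered))))  # ≈ 0
# Without centering, the two normalizations disagree.
print(jnp.max(jnp.abs(layer_normalize(x) - rms_normalize(x))))         # clearly > 0
```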

Why it works

A norm is usually followed by a Dense layer, and when that layer has a bias it can learn any constant shift on its own, so LayerNorm’s β is mostly redundant. The mean subtraction is a weaker case: it is a per-example shift, which no fixed bias can reproduce exactly. The RMSNorm argument is that this re-centering is not what makes normalization useful; the re-scaling is, and RMSNorm keeps it.
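The β half of that argument can be checked directly. A sketch, assuming the norm feeds a plain Dense layer z = h @ W + b (W, b, and β below are random stand-ins, not trained values): passing γ·x̂ + β through the layer gives the same result as passing γ·x̂ through it with the bias replaced by b + β @ W.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
D, H = 4, 3
x_hat = jax.random.normal(k1, (2, D))  # already-normalized activations
gamma = jnp.ones(D)
beta = jax.random.normal(k2, (D,))     # a LayerNorm-style shift
W = jax.random.normal(k3, (D, H))      # next Dense layer's kernel
b = jax.random.normal(k4, (H,))        # next Dense layer's bias

with_beta = (gamma * x_hat + beta) @ W + b
folded = (gamma * x_hat) @ W + (b + beta @ W)  # beta folded into the bias

print(jnp.max(jnp.abs(with_beta - folded)))  # ≈ 0 (float rounding only)
```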

Empirically, models trained with either norm reach very similar loss; RMSNorm just costs less to compute.

Worked example

```python
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])
eps = 1e-6

ms = jnp.mean(x ** 2)     # (1 + 4 + 9 + 16) / 4 = 7.5
rms = jnp.sqrt(ms + eps)  # sqrt(7.5 + eps) ≈ 2.7386
x_hat = x / rms           # ≈ [0.365, 0.730, 1.095, 1.461]
y = 1.0 * x_hat           # γ = 1 at init, so y == x_hat
```

Common pitfalls

  • Subtracting the mean by reflex: it’s not LayerNorm. The whole point is: just divide by RMS.
  • Using β: there is no β. Add it back and you’re not implementing RMSNorm anymore.
  • sqrt(mean(x)) instead of sqrt(mean(x²)): easy typo, big bug.
  • ε in the wrong place: it goes inside the square root, sqrt(mean(x²) + ε), as in LayerNorm. sqrt(mean(x²)) + ε is a different function.
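The last two pitfalls are easy to catch numerically. A small sketch (function names are illustrative): on an input with negative mean, sqrt(mean(x)) takes the square root of a negative number and yields NaN, while the correct form stays finite. On a tiny input, where mean(x²) is comparable to ε, moving ε outside the square root visibly changes the output.

```python
import jax.numpy as jnp

EPS = 1e-6


def rms_norm(x, eps=EPS):
    # Correct: mean of squares, epsilon inside the square root.
    return x / jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + eps)


def buggy_sqrt_mean(x, eps=EPS):
    # Bug: mean(x) instead of mean(x ** 2).
    return x / jnp.sqrt(jnp.mean(x, axis=-1, keepdims=True) + eps)


def buggy_eps_outside(x, eps=EPS):
    # Bug: epsilon added after the square root.
    return x / (jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True)) + eps)


neg = jnp.array([-1., -2., -3., -4.])
print(rms_norm(neg))         # finite: ≈ [-0.365, -0.730, -1.095, -1.461]
print(buggy_sqrt_mean(neg))  # nan everywhere: sqrt of a negative mean

tiny = 1e-3 * jnp.array([1., 2., 3., 4.])  # mean(x²) = 7.5e-6, close to eps
print(rms_norm(tiny)[0], buggy_eps_outside(tiny)[0])  # ≈ 0.343 vs ≈ 0.365
```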

Problem

Implement MyRMSNorm(eps):

  1. Read D = x.shape[-1].
  2. gamma = self.param("gamma", ones, (D,)) — note: no β.
  3. Compute ms = jnp.mean(x ** 2, axis=-1, keepdims=True).
  4. Return gamma * x / jnp.sqrt(ms + eps).

Inputs:

  • seed: float (cast to int).
  • x: 1-D JAX array.
  • eps: float.

Output: 1-D array (same shape).
