Implement RMSNorm (Modern LLM Norm)
Why this matters
RMSNorm is what LayerNorm became in modern LLMs (LLaMA, T5, PaLM, Gemma). It’s strictly simpler than LayerNorm — no mean subtraction, no bias — and empirically performs as well or better while being faster on accelerators.
The bet: LayerNorm’s mean-subtraction step is doing recentering work that the network can absorb into weights anyway, and its β bias has the same role as any other bias in the network. So drop both — keep just the variance scaling and the γ scale.
Math
Compared to LayerNorm:
LayerNorm: x̂ = (x - μ) / sqrt(var + ε); y = γ * x̂ + β
RMSNorm: x̂ = x / sqrt(mean(x²) + ε); y = γ * x̂
Two differences:
- No mean subtraction. Use mean(x²) directly instead of variance.
- No β bias. Just γ scale.
mean(x²) is the mean square, i.e. the squared RMS. So the divisor is sqrt(mean(x²) + ε) ≈ RMS of x (with ε for stability).
Initial γ = 1, so RMSNorm at init divides each row by its RMS — the output has unit RMS regardless of input scale.
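The unit-RMS claim is easy to check numerically. The following is a minimal sketch in plain JAX, assuming a free function `rms_norm` and a helper `rms` (both names are illustrative, not part of the exercise):

```python
import jax.numpy as jnp

def rms_norm(x, gamma, eps=1e-6):
    # Mean square over the feature (last) axis; keepdims so it broadcasts back.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return gamma * x / jnp.sqrt(ms + eps)

def rms(v):
    # Root mean square over the feature axis (illustrative helper).
    return jnp.sqrt(jnp.mean(v ** 2, axis=-1))

x = jnp.array([1.0, 2.0, 3.0, 4.0])
gamma = jnp.ones(x.shape[-1])  # γ at initialization

y_small = rms_norm(x, gamma)
y_large = rms_norm(1000.0 * x, gamma)  # same direction, 1000x the scale

print(rms(y_small), rms(y_large))  # both close to 1
```

Because ε is tiny relative to mean(x²) here, both outputs are nearly identical; for inputs whose mean square is comparable to ε, the output RMS falls below 1.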
Why it’s faster
Per row, LayerNorm needs two reductions across the feature axis: the mean (for centering) and the variance. RMSNorm needs one: the mean square. It also skips the elementwise mean subtraction and the β add. On accelerators where reductions are expensive, this matters: in production LLM training, the norm is in every block, on the critical path of the forward pass.
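To make the cost difference concrete, here is a sketch of both norms in plain, unfused JAX with each feature-axis reduction marked. The function names and signatures are assumptions for illustration; a fused kernel in a real framework may organize the work differently.

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-6):
    mu = jnp.mean(x, axis=-1, keepdims=True)                # reduction over features
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)   # reduction over features
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-6):
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)           # the only reduction
    return gamma * x / jnp.sqrt(ms + eps)

# On a row that already has zero mean, variance equals mean square,
# so the two norms agree when beta = 0.
x = jnp.array([-3.0, -1.0, 1.0, 3.0])
gamma = jnp.ones(4)
beta = jnp.zeros(4)

ln_out = layer_norm(x, gamma, beta)
rn_out = rms_norm(x, gamma)
print(jnp.allclose(ln_out, rn_out, atol=1e-6))
```

The zero-mean check is a useful sanity test: the two functions only diverge when the row mean is nonzero or β is nonzero.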
Why it works
In many architectures a Dense layer sits right after the norm. If that layer has a bias, LayerNorm’s β is mostly redundant: the next layer can recreate any constant shift it wants. Similar logic applies to the mean subtraction: centering is a linear map of the features, so if the network “wants” zero-centered features, the next Dense’s weights can do it.
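The idea that recentering can be absorbed into weights rests on centering being a linear map. A small sketch, assuming a hypothetical Dense kernel `W` and bias `b` chosen only for illustration:

```python
import jax.numpy as jnp

D = 4
x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Centering matrix: C = I - (1/D) * ones((D, D)). x @ C subtracts the mean of x.
C = jnp.eye(D) - jnp.ones((D, D)) / D

W = jnp.arange(12.0).reshape(D, 3) / 10.0  # hypothetical Dense kernel, D -> 3
b = jnp.array([0.1, -0.2, 0.3])            # hypothetical Dense bias

centered_then_dense = (x - jnp.mean(x)) @ W + b  # explicit centering, then Dense
folded_dense = x @ (C @ W) + b                   # centering folded into the kernel

print(jnp.allclose(centered_then_dense, folded_dense, atol=1e-5))
```

This shows only that the centering step itself can be folded into a following kernel. It does not claim LayerNorm and RMSNorm are exactly interchangeable, since the divisor also changes (standard deviation versus RMS).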
Empirically the loss curves are basically identical. RMSNorm just costs less.
Worked example
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])
eps = 1e-6
ms = jnp.mean(x ** 2)     # mean(x^2) = (1 + 4 + 9 + 16) / 4 = 7.5
rms = jnp.sqrt(ms + eps)  # sqrt(7.5 + eps) ≈ 2.7386
x_hat = x / rms           # ≈ [0.365, 0.730, 1.095, 1.461]
y = 1.0 * x_hat           # γ = 1 → y == x_hat
Common pitfalls
- Subtracting the mean by reflex: it’s not LayerNorm. The whole point is: just divide by RMS.
- Using β: there is no β. Add it back and you’re not implementing RMSNorm anymore.
- sqrt(mean(x)) instead of sqrt(mean(x²)): easy typo, big bug.
- ε in the wrong place: it goes inside the square root, sqrt(mean(x²) + ε), like LayerNorm.
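Two of these pitfalls are easy to see on a small input. The sketch below is illustrative only; the variable names are assumptions, and the input is chosen so that mean(x) is negative:

```python
import jax.numpy as jnp

x = jnp.array([-1.0, -2.0, 3.0, -4.0])  # mean(x) = -1, mean(x^2) = 7.5
eps = 1e-6

# Correct: divide by the RMS.
correct = x / jnp.sqrt(jnp.mean(x ** 2) + eps)

# Pitfall: forgot the square, so the square root sees a negative number.
typo = x / jnp.sqrt(jnp.mean(x) + eps)

# Pitfall: subtracting the mean by reflex before dividing.
centered = (x - jnp.mean(x)) / jnp.sqrt(jnp.mean(x ** 2) + eps)

print(correct)
print(typo)      # NaN for this input
print(centered)  # finite, but not the RMSNorm output
```

The missing-square bug is loud here because mean(x) happens to be negative. With an input whose mean is positive it would silently return wrong but finite values, which is harder to spot.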
Problem
Implement MyRMSNorm(eps):
- Read D = x.shape[-1].
- gamma = self.param("gamma", ones, (D,)) (note: no β).
- Compute ms = jnp.mean(x ** 2, axis=-1, keepdims=True).
- Return gamma * x / jnp.sqrt(ms + eps).
Inputs:
- seed: float (cast to int).
- x: 1-D JAX array.
- eps: float.
Output: 1-D array (same shape).