NNX Implement RMSNorm
Why this matters
RMSNorm is the modern LLM normalization: LLaMA, T5, PaLM, Mistral, and Gemma all use RMSNorm instead of LayerNorm. The 2019 RMSNorm paper argued that LayerNorm's benefit comes mainly from re-scaling the activations, and that the mean-centering step is largely redundant. Dropping it yields a simpler layer that's faster, has fewer parameters (no bias), and reached comparable quality in the paper's experiments.
Compared to LayerNorm:
- LayerNorm: `gamma * (x - mean) / sqrt(var + eps) + beta`. Both `gamma` and `beta`, mean-centered.
- RMSNorm: `gamma * x / sqrt(mean(x^2) + eps)`. Just `gamma`, no mean subtraction, no `beta`.
One parameter vector instead of two, no centering step, and in practice the same convergence for transformers. Strictly simpler.
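To make the two formulas concrete, here is a minimal side-by-side sketch in plain `jax.numpy`. The functions `layer_norm` and `rms_norm` are hypothetical helpers written for illustration (they are not Flax APIs), and the parameter counts assume a single normalization layer over `d` features.

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    # Center, then divide by the standard deviation over the last axis.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-5):
    # No centering: divide by the root mean square over the last axis.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return gamma * x / jnp.sqrt(ms + eps)

d = 4
x = jnp.array([1.0, 2.0, 3.0, 4.0])
gamma, beta = jnp.ones((d,)), jnp.zeros((d,))

ln_out = layer_norm(x, gamma, beta)
rms_out = rms_norm(x, gamma)

# Parameters per layer: LayerNorm stores gamma and beta, RMSNorm only gamma.
ln_params = gamma.size + beta.size  # 2 * d
rms_params = gamma.size             # d
```

The LayerNorm output is zero-mean, while the RMSNorm output keeps the input's offset and is only rescaled.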
API
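```python
import jax.numpy as jnp
from flax import nnx

class MyRMSNorm(nnx.Module):
    def __init__(self, d, eps, rngs):
        # Learnable per-feature scale, no bias. rngs is unused: ones needs no randomness.
        self.gamma = nnx.Param(jnp.ones((d,)))
        self.eps = eps

    def __call__(self, x):
        # Mean of squares over the feature axis, kept as shape (..., 1) to broadcast.
        ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
        x_hat = x / jnp.sqrt(ms + self.eps)
        return self.gamma * x_hat
```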
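Because the reduction is over `axis=-1` with `keepdims=True`, the same `__call__` works on inputs with any number of leading batch dimensions, and with `gamma` at its initial value of ones every output row has a root mean square of almost exactly 1 (slightly below, because of `eps`). The sketch below checks this with `rmsnorm_apply`, a hypothetical pure-function mirror of `__call__` that needs only `jax`, on arbitrary random inputs.

```python
import jax
import jax.numpy as jnp

def rmsnorm_apply(x, gamma, eps=1e-5):
    # Same math as MyRMSNorm.__call__, written as a pure function.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)  # shape (..., 1)
    return gamma * (x / jnp.sqrt(ms + eps))

key = jax.random.PRNGKey(0)
x = 3.0 * jax.random.normal(key, (2, 5, 8))  # (batch, seq, features)
gamma = jnp.ones((8,))

y = rmsnorm_apply(x, gamma)
row_rms = jnp.sqrt(jnp.mean(y ** 2, axis=-1))  # one value per (batch, seq) row
```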
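`ms` is the mean of squares (as opposed to the LayerNorm `var`, which is `mean((x - mu) ** 2)`). In general the two are related by `ms = var + mu ** 2`, so when the input has zero mean along the last axis, `ms == var` and dropping the centering step changes nothing. The usual intuition is that transformer activations are roughly mean-zero, in which case the two denominators nearly coincide.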
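A quick numerical check of that relationship, using arbitrary illustrative values: the identity `mean(x^2) = var + mean^2` holds exactly, and after centering the mean-of-squares equals the variance.

```python
import jax.numpy as jnp

def mean_square(x):
    return jnp.mean(x ** 2, axis=-1)

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mu = jnp.mean(x)      # 2.5
var = jnp.var(x)      # population variance (ddof=0): 1.25
ms = mean_square(x)   # 7.5

# Identity: mean(x^2) = var + mean^2, so the gap should be ~0.
gap = ms - (var + mu ** 2)

# After centering, the mean is zero and ms equals var.
ms_centered = mean_square(x - mu)
```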
Why no beta?
Empirically, the bias term in LayerNorm contributes little once the network is large. Removing it saves `d` parameters per normalization layer with no reported quality loss, and later models standardized on the simpler form.
Worked example
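```python
# Uses jnp, nnx, and MyRMSNorm from the API section above.
rngs = nnx.Rngs(0)
model = MyRMSNorm(d=4, eps=1e-5, rngs=rngs)
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = model(x)
# ms = mean([1, 4, 9, 16]) = 7.5
# y[i] = x[i] / sqrt(7.5 + eps)  (gamma = 1)
# y ≈ [0.3651, 0.7303, 1.0954, 1.4606]
```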
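The arithmetic in the worked example can be reproduced without Flax. This sketch recomputes it in plain `jax.numpy`, then swaps in a hypothetical non-default `gamma` to show that the learned scale simply rescales each feature of the normalized vector.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-5

ms = jnp.mean(x ** 2)        # (1 + 4 + 9 + 16) / 4 = 7.5
y = x / jnp.sqrt(ms + eps)   # gamma = ones

gamma = jnp.array([2.0, 1.0, 1.0, 0.5])  # hypothetical learned scale
y_scaled = gamma * y                     # feature-wise rescaling of y
```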
Linen contrast
Linen has nn.RMSNorm, but if you wrote it from scratch:
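```python
import jax.numpy as jnp
import flax.linen as nn

class MyRMSNorm(nn.Module):
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]
        # Learnable per-feature scale, created when model.init runs.
        gamma = self.param("scale", nn.initializers.ones, (d,))
        ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
        return gamma * x / jnp.sqrt(ms + self.eps)
```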
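Same math; the difference is state handling. A Linen module is stateless, so you call `model.init(...)` to create the `scale` parameter and `model.apply(variables, x)` to run it. The NNX module stores `gamma` on the object and is called directly, which removes the init+apply boilerplate.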
Common pitfalls
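- Computing `var` instead of `ms`. `var = mean((x - mu) ** 2)` is LayerNorm's denominator. RMSNorm uses `ms = mean(x ** 2)`, with no centering.
- Adding a `beta`. RMSNorm has no offset parameter; including one gives a non-standard variant.
- Initializing `gamma` to zeros. The outputs would all be zero; ones is the standard initialization.
- Forgetting `keepdims=True`. `ms` needs shape `(..., 1)` to broadcast against `x`. With a 1-D input the mistake happens to work (`ms` is a scalar), but with batched input it either raises a shape error or broadcasts along the wrong axis.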
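These pitfalls are easy to reproduce. The sketch uses an arbitrary 4 x 4 batch, chosen square so that the `keepdims` bug broadcasts silently instead of failing (with a non-square batch such as shape `(2, 4)`, the same bug raises a shape error).

```python
import jax.numpy as jnp

eps = 1e-5
# 4 rows x 4 features: square on purpose, so the shape bug below is silent.
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [2.0, 2.0, 2.0, 2.0],
               [0.0, 1.0, 0.0, 1.0],
               [4.0, 3.0, 2.0, 1.0]])

# var vs ms: they differ whenever a row's mean is non-zero.
ms = jnp.mean(x ** 2, axis=-1, keepdims=True)  # shape (4, 1)
var = jnp.var(x, axis=-1, keepdims=True)       # shape (4, 1)

# Correct normalization, then a zero-initialized gamma wipes it out.
x_hat = x / jnp.sqrt(ms + eps)
zero_out = jnp.zeros((4,)) * x_hat

# Without keepdims, ms has shape (4,) and broadcasts across the feature
# axis instead, so column j is divided by row j's ms.
ms_flat = jnp.mean(x ** 2, axis=-1)            # shape (4,)
wrong = x / jnp.sqrt(ms_flat + eps)
```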
Problem
Write `rmsnorm_forward(seed, x, eps)`:
- Define `MyRMSNorm(nnx.Module)` with one `nnx.Param` `gamma` of shape `(d,)` initialized to ones, and `self.eps = eps`.
- `__call__`: `ms = jnp.mean(x ** 2, axis=-1, keepdims=True)`, `x_hat = x / jnp.sqrt(ms + self.eps)`, return `self.gamma * x_hat`.
- Build with `nnx.Rngs(int(seed))` (unused), instantiate with `d=x.shape[-1]` and `eps=float(eps)`, and return `model(x)`.
Inputs:
- `seed`: int (passed as float).
- `x`: 1-D JAX array.
- `eps`: float.

Output: same shape as `x`.