
NNX: Implement RMSNorm

Why this matters

RMSNorm is the normalization layer of choice in modern LLMs: LLaMA, T5, PaLM, Mistral, and Gemma use RMSNorm (or a close scale-only variant) instead of LayerNorm. The 2019 RMSNorm paper hypothesized that LayerNorm's re-centering (mean subtraction) is dispensable and that its re-scaling does the useful work. Dropping the centering step yields a strictly simpler layer that is cheaper to compute and has fewer parameters (no bias). In that paper's experiments it matched LayerNorm's quality.

Compared to LayerNorm:

  • LayerNorm: gamma * (x - mean) / sqrt(var + eps) + beta. Both gamma and beta, mean-centered.
  • RMSNorm: gamma * x / sqrt(mean(x^2) + eps). Just gamma, no mean subtraction, no beta.

One learned vector instead of two, no centering step, and, in practice, comparable convergence for transformers.
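The two formulas above can be written side by side as plain jax.numpy functions. This is a minimal sketch: `layer_norm` and `rms_norm` are hypothetical helper names (not Flax APIs). Both normalize over the last axis, and gamma and beta are passed in explicitly.

```python
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    # Hypothetical helper: center, divide by the standard deviation, then scale and shift.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta


def rms_norm(x, gamma, eps=1e-5):
    # Hypothetical helper: no centering and no beta, only a rescale by the RMS.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return gamma * x / jnp.sqrt(ms + eps)


d = 4
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [-1.0, 0.0, 1.0, 2.0]])  # a batch of two feature vectors
gamma, beta = jnp.ones((d,)), jnp.zeros((d,))

ln = layer_norm(x, gamma, beta)
rn = rms_norm(x, gamma)
print(ln.mean(axis=-1))  # approximately [0, 0]: LayerNorm centers each row
print(rn.mean(axis=-1))  # nonzero: RMSNorm keeps each row's mean offset
```

LayerNorm carries 2·d learned values per layer (gamma and beta). RMSNorm carries only d.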

API

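import jax.numpy as jnp
from flax import nnx

class MyRMSNorm(nnx.Module):
    def __init__(self, d, eps, rngs):
        # rngs is unused: gamma is initialized deterministically to ones.
        self.gamma = nnx.Param(jnp.ones((d,)))  # learned per-feature scale
        self.eps = eps

    def __call__(self, x):
        # Mean of squares over the feature (last) axis, kept as shape (..., 1).
        ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
        x_hat = x / jnp.sqrt(ms + self.eps)
        return self.gamma * x_hat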

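ms is the mean of squares, as opposed to LayerNorm's var, which is mean((x - mu) ** 2). Expanding the square gives ms = var + mu ** 2, so the two coincide only when the input has zero mean along the feature axis; in that case dropping the centering step changes nothing. When mu is nonzero, RMSNorm divides by a larger quantity than LayerNorm would and leaves the mean in place rather than removing it.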
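A quick numerical check that the mean of squares equals the variance plus the squared mean, and that ms and var agree once the input is centered. This uses plain jax.numpy with example values chosen for illustration.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
mu = jnp.mean(x)                  # 2.5
ms = jnp.mean(x ** 2)             # 7.5  (RMSNorm's statistic)
var = jnp.mean((x - mu) ** 2)     # 1.25 (LayerNorm's statistic)
print(ms, var + mu ** 2)          # both 7.5: mean of squares = variance + squared mean

# After centering, ms equals var, so RMSNorm and LayerNorm (with beta = 0)
# divide by the same quantity.
x0 = x - mu
print(jnp.mean(x0 ** 2), jnp.var(x0))  # both 1.25
```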

Why no beta?

Without mean subtraction the layer only rescales its input, and in large networks an additive offset has been found to contribute little. Removing beta saves d parameters per normalization layer with no significant reported quality loss, and later LLMs standardized on the scale-only form.
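To put the saving in numbers: dropping beta removes d parameters per normalization layer. The sketch below uses a hypothetical decoder configuration (d_model = 4096, 32 blocks, two norms per block plus one final norm), chosen only for illustration.

```python
# Hypothetical configuration, for illustration only.
d_model = 4096
n_blocks = 32
norms_per_block = 2                         # e.g. one before attention, one before the MLP
n_norms = n_blocks * norms_per_block + 1    # plus a final norm

layernorm_params = n_norms * 2 * d_model    # gamma and beta
rmsnorm_params = n_norms * d_model          # gamma only
saved = layernorm_params - rmsnorm_params
print(n_norms, layernorm_params, rmsnorm_params, saved)
# 65 532480 266240 266240
```

Next to a multi-billion-parameter model, this saving is small on its own. The simpler computation matters at least as much as the parameter count.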

Worked example

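import jax.numpy as jnp
from flax import nnx

# Assumes MyRMSNorm is defined as in the API section.
rngs = nnx.Rngs(0)
model = MyRMSNorm(d=4, eps=1e-5, rngs=rngs)
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = model(x)
# ms = mean([1, 4, 9, 16]) = 7.5
# y[i] = x[i] / sqrt(7.5 + 1e-5)  (gamma = 1), with sqrt(7.5) ≈ 2.7386
# y ≈ [0.3651, 0.7303, 1.0954, 1.4606]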
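The arithmetic in the comments can be checked without Flax. This sketch recomputes the worked example with plain jax.numpy (gamma = 1, eps = 1e-5 as above) and confirms the defining property of the output: its root mean square is approximately 1. It also shows how a non-unit gamma (example values) rescales each feature after normalization.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-5
ms = jnp.mean(x ** 2, axis=-1, keepdims=True)   # [7.5]
y = x / jnp.sqrt(ms + eps)                      # gamma = 1

print(y)                              # ≈ [0.3651, 0.7303, 1.0954, 1.4606]
print(jnp.sqrt(jnp.mean(y ** 2)))     # ≈ 1.0 (a hair below, because of eps)

# A learned gamma rescales each feature independently after normalization.
gamma = jnp.array([1.0, 0.5, 2.0, 0.0])
print(gamma * y)                      # ≈ [0.3651, 0.3651, 2.1909, 0.0]
```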

Linen contrast

Linen ships a built-in nn.RMSNorm; written from scratch, it would look like this:

import jax.numpy as jnp
from flax import linen as nn

class MyRMSNorm(nn.Module):
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]
        # Created lazily on the first call (during model.init).
        gamma = self.param("scale", nn.initializers.ones, (d,))
        ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
        return gamma * x / jnp.sqrt(ms + self.eps)

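The math is identical; what NNX removes is the init/apply boilerplate. A Linen module holds no parameters itself: you call model.init(key, x) to get a variables dict and model.apply(variables, x) to run it. An NNX module stores gamma as an attribute and is called directly.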
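The difference in state handling can be mimicked in plain JAX. In this sketch, `rmsnorm_init` and `rmsnorm_apply` are hypothetical functions (not Flax APIs) that imitate the Linen style. Parameters live in an external dict that is created once and passed into every call; an NNX module would instead keep gamma on the object itself.

```python
import jax.numpy as jnp


def rmsnorm_init(d):
    # Hypothetical Linen-style init: build and return the variables tree.
    # Linen nests parameters under a "params" collection; "scale" mirrors the module above.
    return {"params": {"scale": jnp.ones((d,))}}


def rmsnorm_apply(variables, x, eps=1e-6):
    # Hypothetical Linen-style apply: a pure function, so parameters come in as input.
    gamma = variables["params"]["scale"]
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return gamma * x / jnp.sqrt(ms + eps)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
variables = rmsnorm_init(x.shape[-1])   # step 1: create parameters explicitly
y = rmsnorm_apply(variables, x)         # step 2: thread them through every call
print(jnp.round(y, 4))
```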

Common pitfalls

  • Computing var instead of ms. var = mean((x - mu) ** 2) is LayerNorm’s denominator. RMSNorm uses ms = mean(x ** 2) — no centering.
  • Adding a beta. RMSNorm has no offset parameter; including it would be a non-standard variant.
  • Initializing gamma to zeros. Outputs would be zero; ones is standard.
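  • Forgetting keepdims=True. ms needs shape (..., 1) to broadcast against x of shape (..., d). For 1-D input the mistake is hidden because ms is a scalar. For a batch of shape (B, d), a (B,)-shaped ms either fails to broadcast or, when B == d, broadcasts along the wrong axis and silently gives wrong results.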
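The keepdims pitfall is easiest to see with a square batch, where the wrong shape still broadcasts without an error. This sketch (plain jax.numpy, example values chosen for illustration) compares the correct normalization with one that drops keepdims=True.

```python
import jax.numpy as jnp

# A batch of 2 vectors with d = 2, so B == d and the bug does not raise.
x = jnp.array([[3.0, 4.0],
               [1.0, 1.0]])
eps = 1e-6

ms_ok = jnp.mean(x ** 2, axis=-1, keepdims=True)  # shape (2, 1): one value per row
ms_bad = jnp.mean(x ** 2, axis=-1)                # shape (2,): broadcasts across columns

y_ok = x / jnp.sqrt(ms_ok + eps)
y_bad = x / jnp.sqrt(ms_bad + eps)   # column j is divided by row j's statistic

print(jnp.sqrt(jnp.mean(y_ok ** 2, axis=-1)))   # ≈ [1, 1]: every row has unit RMS
print(jnp.sqrt(jnp.mean(y_bad ** 2, axis=-1)))  # rows no longer have unit RMS
```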

Problem

Write rmsnorm_forward(seed, x, eps):

  1. Define MyRMSNorm(nnx.Module) with one nnx.Param, gamma, of shape (d,) initialized to ones, and set self.eps = eps.
  2. __call__: ms = jnp.mean(x ** 2, axis=-1, keepdims=True), x_hat = x / jnp.sqrt(ms + self.eps), return self.gamma * x_hat.
  3. Build with nnx.Rngs(int(seed)) (unused), instantiate (d=x.shape[-1], eps=float(eps)), return model(x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • eps: float.

Output: same shape as x.

