medium primitives

Implement LayerNorm with γ/β

Why this matters

LayerNorm is a quiet workhorse of the Transformer and of many other modern architectures. It stabilizes training by normalizing activations across the feature dimension at every position, which keeps activation scales consistent from layer to layer and tends to give better-behaved gradients.

Unlike BatchNorm, LayerNorm has no batch dependency: the statistics are computed PER-EXAMPLE across the feature axis. As a result it parallelizes trivially across positions, behaves identically at any batch size, and works in autoregressive settings, where normalization must not depend on future tokens.

Math

For input x of shape (..., D), normalize the LAST axis:

μ  = mean(x, axis=-1)            # per-position mean (broadcastable)
σ² = mean((x - μ)², axis=-1)     # per-position variance
x̂  = (x - μ) / √(σ² + ε)
y  = γ ⊙ x̂ + β

where γ and β are learned per-feature scale and shift, both shape (D,).

Initial values: γ = 1, β = 0, so at initialization the affine step is the identity and the layer simply outputs x̂ (zero mean, unit variance along the feature axis).
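The four equations above can be sketched as a plain JAX function. This is a minimal illustration, not a reference implementation: the name layer_norm, the test input, and the size D = 8 are assumptions chosen for the example.

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis, following the equations above."""
    mu = jnp.mean(x, axis=-1, keepdims=True)               # per-position mean
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)  # per-position variance
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

D = 8  # illustrative feature size
x = jnp.arange(2 * D, dtype=jnp.float32).reshape(2, D) ** 2  # two rows of arbitrary values
gamma = jnp.ones((D,))   # initial scale
beta = jnp.zeros((D,))   # initial shift

y = layer_norm(x, gamma, beta)
row_means = y.mean(axis=-1)  # close to 0 for every row
row_vars = y.var(axis=-1)    # close to 1 for every row
```

With γ = 1 and β = 0 the output is just x̂, so each row comes out with mean near 0 and variance near 1 regardless of the scale of the input.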

ε is non-trivial

ε (a small constant, typically 1e-5 or 1e-6) prevents division by zero when x is constant along the feature axis. Without it, a constant-valued input produces NaN immediately.
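A quick way to see this is to normalize a constant vector with and without ε. The sketch below uses an arbitrary constant input of 3.0; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.full((4,), 3.0)  # constant along the feature axis
mu = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)  # exactly 0 here

no_eps = (x - mu) / jnp.sqrt(var)            # 0 / 0 -> nan
with_eps = (x - mu) / jnp.sqrt(var + 1e-5)   # 0 / sqrt(1e-5) -> 0
```

Without ε every entry is NaN; with ε the constant input maps cleanly to zeros.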

Worked numerical example

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])    # shape (4,)
eps = 1e-5

mu = jnp.mean(x)                        # 2.5
var = jnp.mean((x - mu) ** 2)           # 1.25
std = jnp.sqrt(var + eps)               # sqrt(1.25 + eps) ≈ 1.118
x_hat = (x - mu) / std                  # ≈ [-1.342, -0.447, 0.447, 1.342]

# With γ=1, β=0: y == x_hat.

Why mean and var across the LAST axis?

LayerNorm is “per-feature” only in the sense that γ and β are learned per feature; the statistics themselves are computed across the features. The result: each position is normalized to zero mean and unit variance (not unit norm), then the per-feature scale and shift are applied.

For a Transformer with (batch, seq, d_model) shapes, axis=-1 means “across d_model” — exactly what you want. Each (batch, seq) position is normalized independently.
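To make the shapes concrete, here is a small sketch on a (batch, seq, d_model) tensor. The sizes 2, 3, 16 and the helper name normalize_last_axis are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

def normalize_last_axis(x, eps=1e-5):
    """Normalization step only (no gamma/beta), over the last axis."""
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

batch, seq, d_model = 2, 3, 16  # illustrative sizes
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (batch, seq, d_model)) * 5.0 + 2.0

x_hat = normalize_last_axis(x)
per_position_mean = x_hat.mean(axis=-1)  # shape (batch, seq), close to 0
per_position_var = x_hat.var(axis=-1)    # shape (batch, seq), close to 1

# Normalizing one position on its own matches its slice of the full result,
# since no statistics are shared across batch or sequence.
single = normalize_last_axis(x[1, 2])
```

The statistics have shape (batch, seq): one mean and one variance per position, with nothing shared between positions.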

Common pitfalls

  • var = jnp.var(x, axis=-1) (default ddof=0) is fine here — both var and the explicit mean((x-μ)^2) give the same answer. We use the explicit form to make the math obvious.
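  • keepdims=False then subtract or divide: for batched inputs μ and σ² have shape (...,) and no longer broadcast against x of shape (..., D), or, worse, broadcast silently along the wrong axis when the sizes happen to match. Use keepdims=True or reshape μ, σ² back.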
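  • Forgetting ε: divides by zero on constant inputs (0/0 gives NaN).
  • Adding ε in the wrong place: sqrt(var + eps) (inside) is correct. sqrt(var) + eps (outside) is a different operation: it avoids the forward-pass division by zero, but the derivative of sqrt is unbounded at zero variance, so the backward pass can still produce NaN, and the floor on the denominator becomes ε rather than √ε.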
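Two of these pitfalls are easy to reproduce. The sketch below first shows the keepdims shape problem on a (2, 4) input, then compares gradients for the two ε placements on a constant input. The function names are illustrative, and the gradient comparison reflects how JAX differentiates sqrt at zero; treat it as a demonstration under those assumptions rather than a general guarantee.

```python
import jax
import jax.numpy as jnp

# Pitfall: keepdims=False on a batched input.
x = jnp.ones((2, 4))                            # batch of 2, D = 4
mu_flat = jnp.mean(x, axis=-1)                  # shape (2,)
mu_keep = jnp.mean(x, axis=-1, keepdims=True)   # shape (2, 1), broadcasts against x

try:
    x - mu_flat                                 # (2, 4) vs (2,) cannot broadcast
    broadcast_failed = False
except (TypeError, ValueError):
    broadcast_failed = True

# Pitfall: where eps goes.
def norm_eps_inside(v, eps=1e-5):
    mu = jnp.mean(v, axis=-1, keepdims=True)
    var = jnp.mean((v - mu) ** 2, axis=-1, keepdims=True)
    return (v - mu) / jnp.sqrt(var + eps)

def norm_eps_outside(v, eps=1e-5):
    mu = jnp.mean(v, axis=-1, keepdims=True)
    var = jnp.mean((v - mu) ** 2, axis=-1, keepdims=True)
    return (v - mu) / (jnp.sqrt(var) + eps)

const = jnp.full((4,), 3.0)                     # zero variance along the feature axis
grad_inside = jax.grad(lambda v: norm_eps_inside(v).sum())(const)
grad_outside = jax.grad(lambda v: norm_eps_outside(v).sum())(const)
```

Both forward passes return zeros on the constant input, but only the inside placement keeps the gradient finite, because sqrt is never evaluated at exactly zero.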

Problem

Implement MyLayerNorm(eps):

  1. Read D = x.shape[-1].
  2. gamma = self.param("gamma", ones, (D,)).
  3. beta = self.param("beta", zeros, (D,)).
  4. Compute μ = mean(x, axis=-1, keepdims=True), σ² = mean((x-μ)², axis=-1, keepdims=True).
  5. Return gamma * (x - μ) / sqrt(σ² + eps) + beta.

Inputs:

  • seed: float (cast to int).
  • x: 1-D JAX array (shape (D,) for these tests).
  • eps: float.

Output: 1-D array (same shape as x).

Hints

flax layernorm self-param
