Implement LayerNorm with γ/β
Why this matters
LayerNorm is a quiet backbone of the standard Transformer and many other modern architectures. It stabilizes training by normalizing activations across the feature dimension at every position, which tends to make the optimization landscape smoother and gradients better behaved.
Unlike BatchNorm, LayerNorm has no batch dependency: the statistics are computed PER-EXAMPLE (in fact per position) across the feature axis. This makes it embarrassingly parallel, keeps it well-defined for any batch size, and suits autoregressive settings, where a token's normalization must not see statistics from future tokens.
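To make the batch independence concrete, here is a minimal sketch that normalizes the same example once inside a batch and once on its own. Both helpers are hypothetical and omit γ/β; `batch_norm_train` also omits running averages, so it is only a caricature of real training-mode BatchNorm.

```python
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    # Statistics over the feature axis: one (mu, var) pair per example.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def batch_norm_train(x, eps=1e-5):
    # Hypothetical simplified BatchNorm: statistics over the batch axis,
    # no gamma/beta and no running averages.
    mu = jnp.mean(x, axis=0, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=0, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

batch = jnp.array([[1.0, 2.0, 3.0, 4.0],
                   [10.0, 0.0, -5.0, 2.0],
                   [0.5, 0.5, 1.5, 2.5]])
alone = batch[:1]  # the first example as a batch of size 1

# LayerNorm: the first example's output ignores its batch-mates.
ln_in_batch, ln_alone = layer_norm(batch)[0], layer_norm(alone)[0]
# BatchNorm: the same example normalizes differently (all zeros when alone).
bn_in_batch, bn_alone = batch_norm_train(batch)[0], batch_norm_train(alone)[0]
print(jnp.allclose(ln_in_batch, ln_alone))  # True
print(jnp.allclose(bn_in_batch, bn_alone))  # False
```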
Math
For input x of shape (..., D), normalize the LAST axis:
μ = mean(x, axis=-1) # per-position mean (broadcastable)
σ² = mean((x - μ)², axis=-1) # per-position variance
x̂ = (x - μ) / √(σ² + ε)
y = γ ⊙ x̂ + β
where γ and β are learned per-feature scale and shift, both shape (D,).
Initial values: γ = 1, β = 0. At initialization the affine step is therefore the identity, so the layer simply returns x̂: zero mean and unit variance along the feature axis, multiplied by 1 and shifted by 0.
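The formulas translate almost line for line into JAX. A minimal sketch, assuming plain arrays for γ and β rather than learned parameters; the helper name `layer_norm` and the non-default γ/β values are illustrative only:

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    # mu and var over the last axis; keepdims=True keeps them broadcastable.
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    x_hat = (x - mu) / jnp.sqrt(var + eps)
    # Per-feature scale and shift, both shape (D,).
    return gamma * x_hat + beta

D = 4
x = jnp.array([2.0, 4.0, 6.0, 8.0])

# Initial values gamma = 1, beta = 0: the output is just x_hat.
gamma0, beta0 = jnp.ones((D,)), jnp.zeros((D,))
y0 = layer_norm(x, gamma0, beta0)
print(y0.mean(), y0.var())  # approximately 0 and 1

# Made-up values standing in for trained gamma/beta: they act per feature on x_hat.
gamma = jnp.array([1.0, 2.0, 0.5, 1.0])
beta = jnp.array([0.0, 1.0, 0.0, -1.0])
y = layer_norm(x, gamma, beta)
print(jnp.allclose(y, gamma * y0 + beta))  # True
```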
ε is non-trivial
ε (a small constant, typically 1e-5 or 1e-6) prevents division by
zero when x is constant along the feature axis. Without it, a
constant-valued input produces NaN immediately.
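A quick check of that failure mode, using a hypothetical helper `normalize` that exposes ε as an argument:

```python
import jax.numpy as jnp

def normalize(x, eps):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

x_const = jnp.full((4,), 7.0)  # constant along the feature axis

no_eps = normalize(x_const, eps=0.0)     # (x - mu) = 0 and var = 0, so 0 / 0
with_eps = normalize(x_const, eps=1e-5)  # 0 / sqrt(1e-5)
print(no_eps)    # all NaN
print(with_eps)  # all zeros
```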
Worked numerical example
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])  # shape (4,)
eps = 1e-5
mu = jnp.mean(x, axis=-1, keepdims=True)               # 2.5
var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)  # mean([2.25, 0.25, 0.25, 2.25]) = 1.25
# std = sqrt(1.25 + eps) ≈ 1.118
x_hat = (x - mu) / jnp.sqrt(var + eps)                 # ≈ [-1.342, -0.447, 0.447, 1.342]
# With γ=1, β=0: y == x_hat.
Why mean and var across the LAST axis?
LayerNorm is “per-feature” only in the sense that γ and β are learned per feature; the statistics μ and σ² are computed across the features. The result: each position is normalized to zero mean and unit variance (not unit norm; its L2 norm is about √D), and then the per-feature scale and shift are applied.
For a Transformer with (batch, seq, d_model) shapes, axis=-1 means
“across d_model” — exactly what you want. Each (batch, seq) position
is normalized independently.
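A minimal sketch of the Transformer case, assuming a random input of shape (batch, seq, d_model) = (2, 3, 8) and the initial γ = 1, β = 0. It checks that every position ends up with zero mean and unit variance (so its L2 norm is about √8, not 1), and that perturbing one position leaves every other position untouched.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta

B, T, D = 2, 3, 8
x = 3.0 + 5.0 * jax.random.normal(jax.random.PRNGKey(0), (B, T, D))
gamma, beta = jnp.ones((D,)), jnp.zeros((D,))
y = layer_norm(x, gamma, beta)

print(y.mean(axis=-1))              # shape (2, 3), every entry approximately 0
print(y.var(axis=-1))               # shape (2, 3), every entry approximately 1
print(jnp.linalg.norm(y, axis=-1))  # every entry approximately sqrt(8) = 2.83

# Change a single feature at position (batch 0, seq 0) only.
y2 = layer_norm(x.at[0, 0, 0].add(100.0), gamma, beta)
# Every other (batch, seq) position is unaffected.
print(jnp.allclose(y2[0, 1:], y[0, 1:]), jnp.allclose(y2[1], y[1]))
```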
Common pitfalls
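- var = jnp.var(x, axis=-1) (default ddof=0) is fine here: both jnp.var and the explicit mean((x-μ)²) give the same answer. We use the explicit form to make the math obvious.
- keepdims=False, then divide: for a (B, D) input, μ comes out with shape (B,) and broadcasts against the LAST axis, which raises a shape error or, when B == D, silently gives a wrong result. Use keepdims=True or reshape μ and σ² back to (B, 1).
- Forgetting ε: constant inputs then compute 0 / 0 and produce NaN. Always use sqrt(var + eps).
- Adding ε in the wrong place: sqrt(var + eps) (inside) is correct. sqrt(var) + eps (outside) avoids the zero division in the forward pass, but it is a different function from the definition above, and the derivative of sqrt at var = 0 is infinite, so backprop on a constant input can still produce NaN.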
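These pitfalls are easy to reproduce. This sketch uses hypothetical helper names and a (4, 4) input chosen so that B == D; it checks the `jnp.var` equivalence, shows the silent `keepdims=False` broadcast, and compares gradients of the two ε placements on a constant input.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)  # (B, D) with B == D == 4

# 1. jnp.var (default ddof=0) matches the explicit form.
mu = jnp.mean(x, axis=-1, keepdims=True)  # shape (4, 1)
var_explicit = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
var_builtin = jnp.var(x, axis=-1, keepdims=True)
print(jnp.allclose(var_explicit, var_builtin))  # True

# 2. keepdims=False: mu_bad has shape (4,) and broadcasts along the LAST axis,
#    so x[i, j] - mu_bad[j] mixes rows and columns without any error.
mu_bad = jnp.mean(x, axis=-1)  # shape (4,)
print(jnp.allclose(x - mu_bad, x - mu))  # False

# 3. eps inside vs outside the square root, on a constant input.
def ln_inside(v, eps=1e-5):
    m = jnp.mean(v, axis=-1, keepdims=True)
    s2 = jnp.mean((v - m) ** 2, axis=-1, keepdims=True)
    return (v - m) / jnp.sqrt(s2 + eps)

def ln_outside(v, eps=1e-5):
    m = jnp.mean(v, axis=-1, keepdims=True)
    s2 = jnp.mean((v - m) ** 2, axis=-1, keepdims=True)
    return (v - m) / (jnp.sqrt(s2) + eps)

x_const = jnp.full((4,), 3.0)
g_inside = jax.grad(lambda v: ln_inside(v).sum())(x_const)
g_outside = jax.grad(lambda v: ln_outside(v).sum())(x_const)
print(g_inside)   # finite
print(g_outside)  # non-finite: d sqrt(s2) / d s2 is infinite at s2 = 0
```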
Problem
Implement MyLayerNorm(eps):
- Read D = x.shape[-1].
- gamma = self.param("gamma", ones, (D,)).
- beta = self.param("beta", zeros, (D,)).
- Compute μ = mean(x, axis=-1, keepdims=True) and σ² = mean((x-μ)², axis=-1, keepdims=True).
- Return gamma * (x - μ) / sqrt(σ² + eps) + beta.
Inputs:
- seed: float (cast to int).
- x: 1-D JAX array (shape (D,) for these tests).
- eps: float.
Output: 1-D array (same shape as x).
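The problem asks for a Flax-style module built with `self.param`. As a framework-free sketch of the same computation, assuming only `jax` is installed, the parameters can live in a plain dict. The names `init_params`, `apply` and `run_case` are hypothetical, and `run_case` is only a guess at how a harness might use `seed`, `x` and `eps`. Because γ and β are initialized to constants, the seed does not change the output.

```python
import jax
import jax.numpy as jnp

def init_params(key, D):
    # Mirrors self.param("gamma", ones, (D,)) and self.param("beta", zeros, (D,)).
    # The key is unused: ones/zeros initializers are deterministic.
    del key
    return {"gamma": jnp.ones((D,)), "beta": jnp.zeros((D,))}

def apply(params, x, eps):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean((x - mu) ** 2, axis=-1, keepdims=True)
    return params["gamma"] * (x - mu) / jnp.sqrt(var + eps) + params["beta"]

def run_case(seed, x, eps):
    # Hypothetical harness: seed arrives as a float and is cast to int.
    key = jax.random.PRNGKey(int(seed))
    params = init_params(key, x.shape[-1])
    return apply(params, x, eps)

y = run_case(0.0, jnp.array([1.0, 2.0, 3.0, 4.0]), 1e-5)
print(y)  # approximately [-1.342, -0.447, 0.447, 1.342], matching the worked example
```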