Pre-LN vs Post-LN Residual Pattern
Why this matters
The original Transformer (Vaswani et al., 2017) puts LayerNorm AFTER the residual addition: Post-LN. Most modern Transformers (GPT-2, GPT-3, LLaMA, PaLM, ViT, Gemma…) put it BEFORE the sublayer: Pre-LN. (LLaMA and Gemma normalise with RMSNorm instead of LayerNorm, but keep the pre-sublayer placement.) The math difference is one line; the difference in training stability can be large.
Pre-LN is now the default, but the choice has consequences worth understanding, including why the original Transformer depended on learning-rate warmup while Pre-LN models are far less sensitive to it.
The two patterns
Let f(·) be a sublayer (attention or FFN).
- Post-LN (original): y = LayerNorm(x + f(x)). Compute the sublayer on x, ADD the residual, THEN normalise.
- Pre-LN (modern): y = x + f(LayerNorm(x)). Normalise FIRST, run the sublayer on the normalised input, THEN add the (un-normalised) residual.
The difference is which side of the residual the LayerNorm sits on.
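A minimal sketch of the two wrappers in plain jax.numpy, assuming a hand-rolled layer_norm with no learnable scale or bias and an arbitrary sublayer f. The names layer_norm, pre_ln and post_ln are illustrative only, not part of any library.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # LayerNorm over the last axis with no learnable scale/bias
    # (Flax initialises those to 1 and 0, so this matches an untrained LN).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def post_ln(x, f):
    # Post-LN: sublayer on the raw x, add the residual, then normalise the sum.
    return layer_norm(x + f(x))


def pre_ln(x, f):
    # Pre-LN: normalise first, run the sublayer, add the RAW residual.
    return x + f(layer_norm(x))


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
f = lambda h: 0.5 * jnp.tanh(h)  # toy stand-in for attention or an FFN
print(pre_ln(x, f).shape, post_ln(x, f).shape)  # (4, 8) (4, 8)
```

With a sublayer that outputs zeros the difference is easy to see: pre_ln returns x unchanged, while post_ln still returns layer_norm(x).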
Why Pre-LN trains more stably
In Pre-LN, the residual stream x is NEVER normalised; every
sublayer just adds to it. So the gradient of the loss w.r.t. early
layers' inputs has a direct identity path through every residual
addition, with no normalisation in between to rescale it. Gradient
norms therefore stay roughly constant with depth.
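The identity path can be checked numerically. This sketch assumes a toy residual branch tanh(layer_norm(v) @ W) with hypothetical random weights W, and uses jax.jacfwd to compare the Jacobian of the whole Pre-LN block with the Jacobian of its branch alone.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


D = 8
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.1 * jax.random.normal(k_w, (D, D))  # hypothetical sublayer weights
x = jax.random.normal(k_x, (D,))


def branch(v):
    # Residual branch of a Pre-LN block: f(LayerNorm(v)) with a toy f.
    return jnp.tanh(layer_norm(v) @ W)


def pre_ln_block(v):
    return v + branch(v)


J_block = jax.jacfwd(pre_ln_block)(x)  # d block / d x, shape (D, D)
J_branch = jax.jacfwd(branch)(x)       # d branch / d x

# The untouched residual contributes an exact identity term.
print(jnp.allclose(J_block, jnp.eye(D) + J_branch, atol=1e-5))  # True
```

Stacking several such blocks multiplies factors of the form (I + J_branch), and expanding that product always leaves an identity term that carries the gradient straight back to the input.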
In Post-LN, the residual stream goes through a LayerNorm at every
layer. Gradients passing back through LayerNorm get scaled by
that layer's effective Jacobian; multiplied across many layers,
these factors can make gradients explode or vanish. The original
paper's recipe leans on learning-rate warmup: the rate grows
linearly over the first 4,000 steps, then decays.
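To make "scaled by that layer's effective Jacobian" concrete, this sketch evaluates the Jacobian of a parameter-free LayerNorm (an assumption: no learnable scale or bias) at inputs of different magnitude. Its largest singular value is exactly 1/σ, where σ = sqrt(var + eps) of the input, and it maps the all-ones direction to zero.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


z = jax.random.normal(jax.random.PRNGKey(0), (8,))

for scale in (0.1, 1.0, 10.0):
    s = scale * z  # stand-in for a residual sum x + f(x) at different magnitudes
    J = jax.jacfwd(layer_norm)(s)
    sigma = jnp.sqrt(s.var() + 1e-6)
    # The largest singular value of the LN Jacobian equals 1 / sigma.
    spec = float(jnp.linalg.norm(J, 2))
    print(f"scale={scale:5.1f}  ||J||_2={spec:8.4f}  1/sigma={float(1 / sigma):8.4f}")
```

In a Post-LN stack the backward pass picks up one such factor per layer, each set by the magnitude of that layer's residual sum, which is why the product can drift far from 1 as depth grows.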
Deep stacks (24+ layers, and now often far more) are hard to train stably with Post-LN without careful warmup or other tricks (DeepNorm, etc.).
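For reference, the warmup in the original paper (Section 5.3) increases the learning rate linearly for warmup_steps steps, then decays it proportionally to the inverse square root of the step number. A small Python sketch with the paper's base settings, d_model = 512 and warmup_steps = 4000; the function name noam_lr is only a label (the schedule is often nicknamed the "Noam" schedule).

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    # Linear ramp for warmup_steps steps, then inverse-square-root decay.
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:>6}: lr = {noam_lr(step):.2e}")
```

The peak is reached exactly at step == warmup_steps, where the two terms inside min coincide.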
Worked walk-through
With D=8, d_ff=16, sublayer = Dense(d_ff) → relu → Dense(D):
Pre-LN:
- h = LayerNorm(x) → normalised input.
- h = Dense(16)(h), h = relu(h), h = Dense(8)(h) → sublayer output.
- Return x + h: the residual carries the RAW x.
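The Pre-LN steps above, traced in plain JAX with shapes annotated. This is a sketch under assumptions: T = 4 rows, a parameter-free layer_norm, and hypothetical Dense parameters (scaled Gaussian kernels, zero biases) standing in for Flax's initialisers.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the feature axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


T, D, d_ff = 4, 8, 16
k_x, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (T, D))
# Hypothetical Dense params: scaled Gaussian kernels, zero biases.
W1, b1 = jax.random.normal(k1, (D, d_ff)) / jnp.sqrt(D), jnp.zeros(d_ff)
W2, b2 = jax.random.normal(k2, (d_ff, D)) / jnp.sqrt(d_ff), jnp.zeros(D)

h = layer_norm(x)     # (4, 8): normalised input
h = h @ W1 + b1       # (4, 16): Dense(16)
h = jax.nn.relu(h)    # (4, 16)
h = h @ W2 + b2       # (4, 8): Dense(8), the sublayer output
y = x + h             # (4, 8): the residual carries the RAW x
print(y.shape)        # (4, 8)
```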
Post-LN:
- h = Dense(16)(x), h = relu(h), h = Dense(8)(h) → sublayer output.
- s = x + h → residual sum (un-normalised).
- Return LayerNorm(s), which normalises the sum.
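The same trace for Post-LN, under the same assumptions (parameter-free layer_norm, hypothetical Dense parameters). The last two prints check that every row of the output has been renormalised to mean ≈ 0 and variance ≈ 1.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the feature axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


T, D, d_ff = 4, 8, 16
k_x, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (T, D))
# Hypothetical Dense params: scaled Gaussian kernels, zero biases.
W1, b1 = jax.random.normal(k1, (D, d_ff)) / jnp.sqrt(D), jnp.zeros(d_ff)
W2, b2 = jax.random.normal(k2, (d_ff, D)) / jnp.sqrt(d_ff), jnp.zeros(D)

h = x @ W1 + b1       # (4, 16): Dense(16) applied to the RAW x
h = jax.nn.relu(h)    # (4, 16)
h = h @ W2 + b2       # (4, 8): sublayer output
s = x + h             # (4, 8): residual sum, un-normalised
y = layer_norm(s)     # (4, 8): each row now has mean ~0 and variance ~1
print(y.mean(axis=-1))
print(y.var(axis=-1))
```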
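The outputs differ for two reasons. The sublayer sees LayerNorm(x) in
Pre-LN but the raw x in Post-LN. And in Post-LN the final LayerNorm
sees x + h (potentially large or skewed) and rescales the whole output,
while in Pre-LN the output keeps the raw residual.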
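A side-by-side sketch with ONE shared, hypothetical parameter dict, assuming a deliberately large, off-centre input x so the effect is visible. Both patterns consume exactly the same parameters, yet the outputs differ, and only Post-LN comes back with unit per-row standard deviation.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def ffn(p, h):
    # Dense(d_ff) -> relu -> Dense(D), with params in a plain dict.
    return jax.nn.relu(h @ p["W1"] + p["b1"]) @ p["W2"] + p["b2"]


T, D, d_ff = 4, 8, 16
k_x, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = 3.0 * jax.random.normal(k_x, (T, D)) + 1.0  # large, off-centre input
p = {  # one hypothetical param set shared by both patterns
    "W1": jax.random.normal(k1, (D, d_ff)) / jnp.sqrt(D), "b1": jnp.zeros(d_ff),
    "W2": jax.random.normal(k2, (d_ff, D)) / jnp.sqrt(d_ff), "b2": jnp.zeros(D),
}

y_pre = x + ffn(p, layer_norm(x))    # Pre-LN
y_post = layer_norm(x + ffn(p, x))   # Post-LN

print(float(jnp.abs(y_pre - y_post).max()))  # nonzero: the outputs differ
print(y_pre.std(axis=-1))    # tracks the scale of the raw input x
print(y_post.std(axis=-1))   # about 1 for every row
```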
Common pitfalls
- Putting LN inside the residual branch (x + LayerNorm(f(x)) or similar): that's neither Pre-LN nor Post-LN. It's a different variant with its own behaviour, and it won't match either expected output here.
- Assuming the param structure differs: it doesn't. Pre-LN and Post-LN use the SAME params (one LN, two Dense). A different mode gives different OUTPUTS but identical param shapes, which is useful for swapping at training time.
- Branching on a JAX traced value: this problem keeps mode as a Python float (passed in from outside), so a Python if works. Don't use jnp.where to pick the architecture; it would evaluate both paths.
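The branching pitfall can be reproduced in a few lines. A sketch assuming a recent JAX; pick_path is a hypothetical stand-in in which the two architectures are replaced by trivial arithmetic. Inside jax.jit a traced mode can't drive a Python if, while static_argnums (or keeping mode outside the traced function, as this problem does) leaves it a concrete float.

```python
import jax
import jax.numpy as jnp


def pick_path(x, mode):
    # Python-level branch: needs `mode` to be a concrete Python value.
    if mode >= 0.5:
        return x + 1.0   # stand-in for the Pre-LN path
    return x - 1.0       # stand-in for the Post-LN path


x = jnp.zeros(3)
print(pick_path(x, 0.7))  # eager call: mode is a plain float, so `if` works

try:
    jax.jit(pick_path)(x, 0.7)  # mode becomes a tracer inside jit
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True  # `if` cannot turn a traced bool into True/False
print("branching on traced mode failed:", traced_failed)

# Marking mode as static keeps it a Python float at trace time.
pick_static = jax.jit(pick_path, static_argnums=1)
print(pick_static(x, 0.7))
```

Each distinct static value triggers a fresh trace and compile, which is cheap for a two-way architecture switch.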
Problem
Implement pre_vs_post_ln(seed, x, mode, d_ff):
- pre = float(mode) >= 0.5. Use a Python if pre: to pick the path.
- Sublayer is the FFN: Dense(d_ff) → relu → Dense(D) where D = x.shape[-1].
- Pre-LN: return x + sublayer(LayerNorm(x)).
- Post-LN: return LayerNorm(x + sublayer(x)).
- Flatten the output.
Build a small nn.Module with a pre: bool field and branch in
@nn.compact.
Inputs:
- seed: int.
- x: 2-D (T, D).
- mode: float. >= 0.5 → Pre-LN; else Post-LN.
- d_ff: int FFN hidden dim.
Output: 1-D, flattened (T, D).