Pre-LN vs Post-LN Residual Pattern

Why this matters

The original Transformer (Vaswani et al., 2017) puts LayerNorm AFTER the residual addition (Post-LN). Most modern Transformers (GPT-2, GPT-3, LLaMA, PaLM, ViT, Gemma…) put the normalisation BEFORE the sublayer (Pre-LN), and some of them, such as LLaMA, use RMSNorm in place of LayerNorm. The math difference is one line; the difference in training stability can be large.

Pre-LN is now the default. But the choice has consequences worth understanding, including why the original Transformer relied on learning-rate warmup while Pre-LN models are far less sensitive to it.

The two patterns

Let f(·) be a sublayer (attention or FFN).

  • Post-LN (original):

    y = LayerNorm(x + f(x))

    Compute the sublayer on x, ADD the residual, THEN normalise.

  • Pre-LN (modern):

    y = x + f(LayerNorm(x))

    Normalise FIRST, run the sublayer on the normalised input, THEN add the (un-normalised) residual.

The difference is which side of the residual the LayerNorm sits on.
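The two formulas can be written as a pair of one-line functions. This is a minimal sketch in plain JAX: the layer_norm helper has no learned scale or bias, and jnp.tanh is only a hypothetical stand-in for a real attention or FFN sublayer.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Normalise over the last axis; learned scale/bias are omitted in this sketch.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def post_ln(x, f):
    # Original Transformer: sublayer, add residual, then normalise the sum.
    return layer_norm(x + f(x))

def pre_ln(x, f):
    # Modern default: normalise, run the sublayer, add the raw residual.
    return x + f(layer_norm(x))

f = jnp.tanh  # hypothetical stand-in sublayer (assumption, not from the problem)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 3.0 + 1.0

y_post = post_ln(x, f)
y_pre = pre_ln(x, f)
```

Each row of y_post has mean close to 0 and variance close to 1, while y_pre keeps the scale and offset of x.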

Why Pre-LN trains more stably

In Pre-LN, the residual stream x is NEVER normalised inside the stack; every sublayer just adds to it. So the gradient of the loss w.r.t. an early layer’s input has a direct identity path through every residual addition, with no normalisation in between to rescale it. At initialisation, gradient norms stay roughly constant with depth. (Most Pre-LN models do apply one final LayerNorm after the last block.)

In Post-LN, the residual stream goes through a LayerNorm at every layer. Gradients passing back through each LayerNorm are multiplied by its Jacobian; compounded across many layers, they can explode or vanish. The original paper’s remedy is learning-rate warmup: start small and increase the learning rate linearly over the first 4,000 steps, then decay it.

Deep stacks (24+ layers, and increasingly 100+) are hard to train stably with Post-LN without careful warmup or other tricks such as DeepNorm.
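The "direct path" argument can be checked on a single block by looking at its Jacobian. The sketch below is a toy illustration, not a proof of the stability claim: it assumes a sublayer whose output is exactly zero (zero_sublayer is a hypothetical helper), so that only the effect of the LayerNorm placement remains.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # LayerNorm over the last axis, without learned scale/bias.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def zero_sublayer(h):
    # Hypothetical sublayer that contributes nothing, to isolate the LN placement.
    return jnp.zeros_like(h)

def pre_ln(x):
    return x + zero_sublayer(layer_norm(x))

def post_ln(x):
    return layer_norm(x + zero_sublayer(x))

x = jax.random.normal(jax.random.PRNGKey(0), (8,))

# Pre-LN: the block reduces to the identity, so gradients pass through unchanged.
J_pre = jax.jacobian(pre_ln)(x)

# Post-LN: the block's Jacobian is LayerNorm's Jacobian, which rescales
# the gradient and removes its component along the all-ones direction.
J_post = jax.jacobian(post_ln)(x)
```

With a real sublayer, the Pre-LN Jacobian is the identity plus the Jacobian of the sublayer branch, so the identity term is always present. In Post-LN, everything is multiplied by the LayerNorm Jacobian at every layer.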

Worked walk-through

With D=8, d_ff=16, sublayer = Dense(d_ff) → relu → Dense(D):

Pre-LN:

  1. h = LayerNorm(x) → normalised input.
  2. h = Dense(16)(h), h = relu(h), h = Dense(8)(h) → sublayer output.
  3. Return x + h — residual has the RAW x.

Post-LN:

  1. h = Dense(16)(x), h = relu(h), h = Dense(8)(h) → sublayer output.
  2. s = x + h → residual sum (un-normalised).
  3. Return LayerNorm(s) — normalises the sum.

The outputs differ because in Post-LN the LayerNorm sees the sum x + h (potentially large or skewed) and the block’s output is normalised, while in Pre-LN the LayerNorm sees only the raw x and the output x + h is left un-normalised.
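The walk-through can be traced step by step in plain JAX. This sketch makes several assumptions: T = 4 rows, random weights W1 and W2 standing in for the two Dense layers, no biases, and a LayerNorm without learned scale or bias.

```python
import jax
import jax.numpy as jnp

D, d_ff, T = 8, 16, 4  # T is an arbitrary choice for this sketch

k_x, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (T, D))
W1 = jax.random.normal(k1, (D, d_ff)) / jnp.sqrt(D)      # stands in for Dense(16)
W2 = jax.random.normal(k2, (d_ff, D)) / jnp.sqrt(d_ff)   # stands in for Dense(8)

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def ffn(h):
    # Dense(d_ff) -> relu -> Dense(D), biases omitted.
    return jax.nn.relu(h @ W1) @ W2

# Pre-LN: normalise, run the sublayer, add the RAW x.
h = layer_norm(x)
h = ffn(h)
y_pre = x + h

# Post-LN: run the sublayer on x, add the residual, normalise the sum.
h = ffn(x)
s = x + h
y_post = layer_norm(s)
```

Both results have shape (T, D), but only y_post has rows with mean close to 0 and variance close to 1.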

Common pitfalls

  • Putting LN on the sublayer branch only (x + LayerNorm(f(x)) or similar): that’s neither Pre-LN nor Post-LN. It is a different architecture, and it will not reproduce either path of this problem.
  • Assuming the param structure differs: Pre-LN and Post-LN use the SAME params (one LN, two Dense). Switching mode changes the OUTPUTS but not the param shapes, which is useful for swapping modes without touching the params.
  • Branching on a JAX traced value: this problem keeps mode as a Python float (passed in from outside), so a plain Python if works. Don’t use jnp.where to select the architecture.
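The last pitfall can be illustrated with jax.jit. In this sketch, the mode float is converted to a Python bool outside the jitted function and passed as a static argument, so the if is resolved at trace time. The block function, the tanh sublayer, and the weight w are hypothetical names used only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

@partial(jax.jit, static_argnames="pre")
def block(x, w, pre):
    f = lambda h: jnp.tanh(h @ w)  # hypothetical stand-in sublayer
    if pre:  # plain Python branch: `pre` is static, not a traced array
        return x + f(layer_norm(x))
    return layer_norm(x + f(x))

mode = 1.0
pre = float(mode) >= 0.5  # decided in Python, before any tracing

k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (4, 8))
w = jax.random.normal(k_w, (8, 8)) / jnp.sqrt(8.0)

y = block(x, w, pre=pre)
```

Because pre is static, JAX compiles one version of block per value of pre, and each compiled version contains only the chosen path.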

Problem

Implement pre_vs_post_ln(seed, x, mode, d_ff):

  1. pre = float(mode) >= 0.5. Use a Python if pre: to pick the path.
  2. Sublayer is the FFN: Dense(d_ff) → relu → Dense(D) where D = x.shape[-1].
  3. Pre-LN: return x + sublayer(LayerNorm(x)).
  4. Post-LN: return LayerNorm(x + sublayer(x)).
  5. Flatten output.

Build a small nn.Module with a pre: bool field and branch in @nn.compact.

Inputs:

  • seed: int.
  • x: 2-D (T, D).
  • mode: float. >= 0.5 → Pre-LN; else Post-LN.
  • d_ff: int FFN hidden dim.

Output: 1-D array of length T·D (the (T, D) result, flattened).

Hints

flax transformer layernorm residual
