
Transformer Encoder Block (Pre-LN)

Why this matters

The encoder block is the canonical Transformer building unit. BERT, ViT, T5’s encoder, and dozens of other models stack many copies of this same recipe. Once you can write one block end-to-end, the rest of the architecture is “stack N of these and add embeddings.”

The block has two sub-layers:

  1. Multi-head self-attention — tokens look at each other.
  2. Position-wise feed-forward (FFN) — each token transforms its own vector independently.

Each sub-layer is wrapped with a residual connection and a LayerNorm. Two LayerNorms, two residuals, every block.

Pre-LN vs Post-LN

Two ways to wire the LayerNorm + residual:

  • Post-LN (original 2017): x = LN(x + sublayer(x)), norm AFTER the residual. Every block's output comes out normalized, but deep stacks need a careful learning-rate warm-up to train, and without it gradients can explode.
  • Pre-LN (modern, GPT-2 onwards): x = x + sublayer(LN(x)), norm BEFORE the sublayer, residual path untouched. Far more stable for deep models, and the default in most modern Transformers.
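A minimal sketch of the two wirings in plain JAX. It assumes a parameter-free LayerNorm (no γ, β) and a stand-in sublayer that returns zeros. The helper names layer_norm, pre_ln and post_ln are illustrative, not Flax APIs. The sketch isolates the one structural difference: where the normalization sits relative to the residual add.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the feature axis (no learned gamma/beta),
    # kept minimal so only the wiring differs between the two variants.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def post_ln(x, sublayer):
    # Post-LN: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x))


def pre_ln(x, sublayer):
    # Pre-LN: normalize only the sublayer's input; the residual path is untouched.
    return x + sublayer(layer_norm(x))


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 3.0 + 5.0
zero = lambda h: jnp.zeros_like(h)  # a sublayer that contributes nothing

print(jnp.allclose(pre_ln(x, zero), x))               # True: x passes straight through
print(jnp.allclose(post_ln(x, zero), layer_norm(x)))  # True: output is always re-normalized
```

With a zero sublayer, Pre-LN hands x through unchanged, while Post-LN still rescales it. Across a deep stack, Pre-LN therefore keeps a clean identity path from input to output, which is the usual intuition for its better stability.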

This problem uses Pre-LN:

x → LN → MHA       → +x  → LN → FFN → +x → out
    \___ residual ___/      \__ residual __/

The FFN

The position-wise feed-forward network is two Dense layers with a nonlinearity between:

h = nn.Dense(d_ff)(x)     # (T, D) → (T, d_ff); usually d_ff = 4·D
h = nn.relu(h)
h = nn.Dense(D)(h)        # (T, d_ff) → (T, D)

“Position-wise” because the same (D → d_ff → D) MLP, with the same weights, is applied at every token position independently; nothing mixes across positions. Attention mixes information across positions; the FFN mixes the feature channels within each token.
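To make “position-wise” concrete, here is a sketch in plain JAX with explicitly passed weights. init_ffn and ffn are hypothetical helpers, and the 1/√fan-in weight scaling is just a convenient choice. Applying the MLP to the whole (T, D) matrix gives the same result as applying it to each token separately with jax.vmap. Editing one token leaves the other rows of the output unchanged.

```python
import jax
import jax.numpy as jnp


def init_ffn(key, d_model, d_ff):
    # Hypothetical helper: random weights for the D -> d_ff -> D MLP.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_model, d_ff)) / jnp.sqrt(d_model),
        "b1": jnp.zeros(d_ff),
        "w2": jax.random.normal(k2, (d_ff, d_model)) / jnp.sqrt(d_ff),
        "b2": jnp.zeros(d_model),
    }


def ffn(params, x):
    # Works on one token (D,) or a sequence (T, D): the matmuls only contract
    # the last (feature) axis, so different positions never interact.
    h = x @ params["w1"] + params["b1"]     # (..., D) -> (..., d_ff)
    h = jax.nn.relu(h)
    return h @ params["w2"] + params["b2"]  # (..., d_ff) -> (..., D)


T, D, d_ff = 4, 8, 16
params = init_ffn(jax.random.PRNGKey(0), D, d_ff)
x = jax.random.normal(jax.random.PRNGKey(1), (T, D))

whole = ffn(params, x)                             # all positions at once
per_token = jax.vmap(lambda t: ffn(params, t))(x)  # one position at a time
print(whole.shape, jnp.allclose(whole, per_token, atol=1e-5))  # (4, 8) True

x_edit = x.at[0].set(100.0)                               # change token 0 only
print(jnp.allclose(ffn(params, x_edit)[1:], whole[1:]))   # True: tokens 1..3 unaffected
```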

Worked walk-through

With T=4, D=8, H=2, d_ff=16:

  1. x ∈ (4, 8).
  2. h = LN(x); h = MHA(h) → (4, 8); x = x + h.
  3. h = LN(x); h = Dense(16)(h) → relu → Dense(8)(h); x = x + h.
  4. Return x.reshape(-1) → (32,).

Both LayerNorms have their own learned (γ, β); both Dense layers have their own (W, b); the MHA has its own Q/K/V/O projections. Flax tracks them all under unique scope names automatically.

Common pitfalls

  • Forgetting a residual: if you write x = sublayer(LN(x)) without + x, you’ve lost the skip — training will be much harder.
  • Putting LN AFTER the sublayer in Pre-LN: that’s Post-LN; you lose the stability advantage.
  • Sharing one nn.LayerNorm instance for both spots: they should be DIFFERENT modules (different learned scales). Just call nn.LayerNorm() twice inside @nn.compact.
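  • Wrong FFN sizing: Dense(d_ff) then Dense(D), in that order. Flipping them makes the FFN output d_ff features instead of D, so the residual add fails with a shape mismatch (unless d_ff happens to equal D).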
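A small plain-JAX check of the FFN-sizing pitfall. It assumes bias-free random weights, and ffn_out is a hypothetical helper. In the correct order the FFN returns (T, D); flipped, it returns (T, d_ff), and the residual add then fails. The sketch catches both TypeError and ValueError rather than assume which one a given JAX version raises.

```python
import jax
import jax.numpy as jnp

T, D, d_ff = 4, 8, 16
key_x, key_1, key_2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (T, D))


def ffn_out(first, second):
    # Hypothetical helper: a D -> first -> second MLP with biases omitted.
    w1 = jax.random.normal(key_1, (D, first))
    w2 = jax.random.normal(key_2, (first, second))
    return jax.nn.relu(x @ w1) @ w2


good = ffn_out(d_ff, D)  # Dense(d_ff) then Dense(D)
bad = ffn_out(D, d_ff)   # flipped order
print(good.shape, bad.shape)  # (4, 8) (4, 16)

try:
    x + bad  # residual add with mismatched feature widths
except (TypeError, ValueError) as err:
    print("residual add fails:", type(err).__name__)
```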

Problem

Implement encoder_block_forward(seed, x, num_heads, d_ff) using a Pre-LN encoder block:

  1. D = x.shape[-1].
  2. h = LayerNorm()(x); h = MHA(h); x = x + h.
  3. h = LayerNorm()(x); h = Dense(d_ff)(h); h = relu(h); h = Dense(D)(h); x = x + h.
  4. Return x.reshape(-1).

Build a small nn.Module (e.g. EncoderBlock) and assemble the layers inside its @nn.compact __call__. Initialize the parameters with jax.random.PRNGKey(seed), then apply the module to x.

Inputs:

  • seed: int.
  • x: 2-D (T, D) input.
  • num_heads: int H. D must be divisible by H.
  • d_ff: int FFN hidden dim.

Output: 1-D, the flattened (T, D) output.

Hints

flax transformer encoder architecture
