Transformer Encoder Block (Pre-LN)
Why this matters
The encoder block is the canonical Transformer building unit. BERT,
ViT, T5’s encoder, and dozens of other models stack many copies of
this same recipe. Once you can write one block end-to-end, the rest
of the architecture is “stack N of these and add embeddings.”
The block has two sub-layers:
- Multi-head self-attention — tokens look at each other.
- Position-wise feed-forward (FFN) — each token transforms its own vector independently.
Each sub-layer is wrapped with a residual connection and a LayerNorm. Two LayerNorms, two residuals, every block.
Pre-LN vs Post-LN
Two ways to wire the LayerNorm + residual:
- Post-LN (original 2017): x = LN(x + sublayer(x)). The norm comes AFTER the residual. The math is cleaner, but training deep stacks needs careful warm-up; gradients can explode.
- Pre-LN (modern, GPT-2 onwards): x = x + sublayer(LN(x)). The norm comes BEFORE the sublayer and the residual path is untouched. Far more stable for deep models, and the default in essentially every modern Transformer.
This problem uses Pre-LN:
x → LN → MHA → +x → LN → FFN → +x → out
\___ residual ___/ \__ residual __/
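As a rough illustration of the two wirings, the sketch below uses a parameter-free layer_norm and a made-up toy_sublayer standing in for MHA or the FFN. Both helpers are assumptions for the demo, not part of the problem's required API.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the last axis (no learned scale/shift; demo only).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def post_ln(x, sublayer):
    # Post-LN: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x))

def pre_ln(x, sublayer):
    # Pre-LN: normalize only the sublayer input; the skip path carries x unchanged.
    return x + sublayer(layer_norm(x))

def toy_sublayer(h):
    # Hypothetical stand-in for MHA or the FFN; its output is bounded by 0.5.
    return 0.5 * jnp.tanh(h)

x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 8))
y_post = post_ln(x, toy_sublayer)
y_pre = pre_ln(x, toy_sublayer)

# Post-LN re-normalizes every token, so the per-token mean is close to 0.
print(jnp.abs(y_post.mean(axis=-1)).max())
# Pre-LN leaves x on the skip path, so y_pre differs from x only by the sublayer output.
print(jnp.abs(y_pre - x).max())
```

In the Pre-LN version the input reaches the output through a pure addition, which is the property the stability argument above relies on.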
The FFN
The position-wise feed-forward network is two Dense layers with a nonlinearity between:
h = Dense(d_ff)(x) # (T, D) → (T, d_ff) — usually d_ff = 4·D
h = relu(h)
h = Dense(D)(h) # (T, d_ff) → (T, D)
“Position-wise” because the same (D → d_ff → D) MLP is applied at
every token position independently — no mixing across positions.
Mixing across tokens happens in attention; mixing across the features of a single token happens in the FFN.
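One way to convince yourself of the "position-wise" claim is to apply the same MLP to the whole sequence and to each token separately. The sketch below uses hand-rolled weights (W1, b1, W2, b2 are illustrative names, not Flax parameters) so that the check does not depend on any module API.

```python
import jax
import jax.numpy as jnp

T, D, d_ff = 4, 8, 32  # d_ff = 4 * D, the usual ratio

k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 3)
W1, b1 = 0.1 * jax.random.normal(k1, (D, d_ff)), jnp.zeros(d_ff)
W2, b2 = 0.1 * jax.random.normal(k2, (d_ff, D)), jnp.zeros(D)

def ffn(x):
    # (..., D) -> (..., d_ff) -> (..., D); acts on the last axis only.
    h = x @ W1 + b1
    h = jax.nn.relu(h)
    return h @ W2 + b2

x = jax.random.normal(kx, (T, D))

whole = ffn(x)                                         # all tokens at once
per_token = jnp.stack([ffn(x[t]) for t in range(T)])   # one token at a time

# Reordering the tokens just reorders the outputs: no cross-position mixing.
perm = jnp.array([2, 0, 3, 1])
permuted = ffn(x[perm])

print(jnp.allclose(whole, per_token, atol=1e-5))
print(jnp.allclose(permuted, whole[perm], atol=1e-5))
```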
Worked walk-through
With T=4, D=8, H=2, d_ff=16:
- x ∈ (4, 8).
- h = LN(x); h = MHA(h) → (4, 8); x = x + h.
- h = LN(x); h = Dense(16)(h) → (4, 16); h = relu(h); h = Dense(8)(h) → (4, 8); x = x + h.
- Return x.reshape(-1) → (32,).
Both LayerNorms have their own learned (γ, β); both Dense layers
have their own (W, b); the MHA has its own Q/K/V/O projections.
Flax tracks them all under unique scope names automatically.
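To make the parameter inventory concrete, here is a back-of-the-envelope count for the walk-through sizes. It assumes every projection carries a bias (the Flax default, as far as I know) and that the attention projections are all D × D; the helper name encoder_block_param_count is made up for this sketch.

```python
def encoder_block_param_count(D, d_ff):
    # Two LayerNorms, each with a scale (gamma) and a shift (beta) of size D.
    layer_norms = 2 * (2 * D)
    # Q, K, V and output projections: four D x D kernels, each with a bias of size D.
    # The head count does not appear because num_heads * head_dim == D.
    attention = 4 * (D * D + D)
    # FFN: Dense(d_ff) then Dense(D), each with kernel and bias.
    ffn = (D * d_ff + d_ff) + (d_ff * D + D)
    return {
        "layer_norms": layer_norms,
        "attention": attention,
        "ffn": ffn,
        "total": layer_norms + attention + ffn,
    }

counts = encoder_block_param_count(D=8, d_ff=16)
print(counts)
# {'layer_norms': 32, 'attention': 288, 'ffn': 280, 'total': 600}
```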
Common pitfalls
- Forgetting a residual: if you write x = sublayer(LN(x)) without + x, you’ve lost the skip, and training will be much harder.
- Putting LN AFTER the sublayer in Pre-LN: that’s Post-LN; you lose the stability advantage.
- Sharing one nn.LayerNorm instance for both spots: they should be DIFFERENT modules (different learned scales). Just call nn.LayerNorm() twice inside @nn.compact.
- Wrong FFN sizing: Dense(d_ff) then Dense(D), in that order. Flipping them leaves the output at width d_ff, which no longer matches x for the residual add.
Problem
Implement encoder_block_forward(seed, x, num_heads, d_ff) using a
Pre-LN encoder block:
- D = x.shape[-1].
- h = LayerNorm()(x); h = MHA(h); x = x + h.
- h = LayerNorm()(x); h = Dense(d_ff)(h); h = relu(h); h = Dense(D)(h); x = x + h.
- Return x.reshape(-1).
Build a small nn.Module (e.g. EncoderBlock) and assemble inside
@nn.compact. Init with jax.random.PRNGKey(seed), apply on x.
Inputs:
- seed: int.
- x: 2-D (T, D) input.
- num_heads: int H. D must be divisible by H.
- d_ff: int FFN hidden dim.
Output: 1-D, the flattened (T, D) output.