Transformer Decoder Block (Pre-LN)
Why this matters
The decoder block is what makes encoder-decoder Transformers (translation, T5, BART) work. It has THREE sub-layers, one more than the encoder block, because the decoder must:
- Attend to its own past (causal self-attention).
- Read from the encoder output (cross-attention).
- Transform per-token (FFN).
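Each sub-layer is wrapped in a residual connection, with LayerNorm applied to the sub-layer's input (Pre-LN style here). The three-sub-layer structure is that of the Vaswani et al. 2017 decoder layer; the original paper placed LayerNorm after each residual (Post-LN), and Pre-LN is the variant used here.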
The three sub-layers
x_q → LN → causal_self_attn(LN(x_q)) → +x_q (sub-layer 1)
→ LN → cross_attn(LN(x_q), LN(x_kv)) → +x_q (sub-layer 2)
→ LN → FFN(LN(x_q)) → +x_q (sub-layer 3)
→ out
Sub-layer 2 is the key innovation: queries come from the decoder
(x_q), keys/values come from the encoder (x_kv). This is how
the decoder accesses the encoder's representation.
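The wiring above can be sketched in plain JAX. This is only an illustration of the Pre-LN residual pattern: `layer_norm` is a parameter-free stand-in for a learned LayerNorm, and `pre_ln_decoder_wiring` plus the three zero-valued sub-layer callables are hypothetical names chosen so the snippet runs on its own.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the feature axis (no learned scale/bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def pre_ln_decoder_wiring(x_q, x_kv, self_attn, cross_attn, ffn):
    # Sub-layer 1: normalize, causal self-attention, add the residual.
    x_q = x_q + self_attn(layer_norm(x_q))
    # Sub-layer 2: queries from the decoder, keys/values from the encoder.
    x_q = x_q + cross_attn(layer_norm(x_q), layer_norm(x_kv))
    # Sub-layer 3: position-wise feed-forward, add the residual.
    x_q = x_q + ffn(layer_norm(x_q))
    return x_q


# Hypothetical stand-in sub-layers that return zeros, so only the wiring is exercised.
x_q = jnp.ones((4, 8))
x_kv = jnp.ones((3, 8))
out = pre_ln_decoder_wiring(
    x_q,
    x_kv,
    self_attn=lambda h: jnp.zeros_like(h),
    cross_attn=lambda q, kv: jnp.zeros_like(q),
    ffn=lambda h: jnp.zeros_like(h),
)
print(out.shape)
```

Because every sub-layer only adds to x_q, the residual path carries the input through unchanged when the sub-layers output zeros.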
Causal self-attention
Self-attention with a lower-triangular mask so position t only
attends to positions ≤ t. Same as the encoder block's MHA, just
with mask=jnp.tril(jnp.ones((T_q, T_q))) passed in.
Without the causal mask, training would leak future tokens into
each prediction; see pos 23 (flax-mha-causal).
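To see what the mask does to the attention weights, here is a minimal single-head sketch in plain JAX. The scores are random stand-ins rather than real query-key products, and a boolean mask is used for clarity; the softmax step mirrors the usual masking recipe but is not taken from any particular library's internals.

```python
import jax
import jax.numpy as jnp

T_q = 4
mask = jnp.tril(jnp.ones((T_q, T_q), dtype=bool))

# Arbitrary stand-in attention scores for a single head.
scores = jax.random.normal(jax.random.PRNGKey(0), (T_q, T_q))

# Masked-out entries get a very negative score, so softmax sends them to zero.
masked_scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
weights = jax.nn.softmax(masked_scores, axis=-1)

print(mask.astype(jnp.int32))
print(jnp.round(weights, 2))
```

Row t of `weights` has non-zero entries only in columns 0..t, and each row still sums to 1.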
Cross-attention
Cross-attention is the two-input MHA call:
out = nn.MultiHeadDotProductAttention(...)(inputs_q, inputs_kv)
T_q and T_kv can differ: the decoder might be 4 tokens long
while the encoder fed 16. The score matrix is (T_q, T_kv), the
output is (T_q, D). Same query length in, same query length out.
No mask here: the decoder is free to read any encoder position.
Worked walk-through
With T_q=4, T_kv=3, D=8, H=2, d_ff=16:
- Causal self-attn on x_q: (4, 8) → (4, 8). Residual.
- Cross-attn: Q from updated x_q, K/V from x_kv. (4, 8) out. Residual.
- FFN per position: Dense(16) → relu → Dense(8). Residual.
- Output (4, 8), same shape as x_q.
Three LayerNorms, three residuals, three sub-layers: count them.
Common pitfalls
- Forgetting the causal mask on the FIRST attention sub-layer. That's the difference between training a decoder and training a bi-directional encoder.
- Putting a mask on the cross-attention: there is no causality across encoder/decoder, because the encoder sequence is fully observed.
- Mixing up which input is Q vs KV: queries are always from the decoder side (the side being generated). K/V from the encoder.
- Skipping a residual: easy to forget the third one. All three sub-layers MUST have residuals for stable training.
Problem
Implement decoder_block_forward(seed, x_q, x_kv, num_heads, d_ff)
using a Pre-LN decoder block:
- D = x_q.shape[-1]; T = x_q.shape[0]; mask = jnp.tril(jnp.ones((T, T))).
- Causal self-attn on LN(x_q), residual.
- Cross-attn from LN(x_q) queries to LN(x_kv) keys/values, residual.
- FFN on LN(x_q): Dense(d_ff) → relu → Dense(D), residual.
- Return x_q.reshape(-1).
Build a small nn.Module (DecoderBlock) whose __call__ is decorated with @nn.compact.
Inputs:
- seed: int.
- x_q: 2-D (T_q, D).
- x_kv: 2-D (T_kv, D).
- num_heads: int H. D divisible by H.
- d_ff: int FFN hidden dim.
Output: 1-D, the flattened (T_q, D) output.