Transformer Decoder Block (Pre-LN)

Why this matters

The decoder block is what makes encoder-decoder Transformers (translation, T5, BART) work. It has THREE sub-layers, one more than the encoder block, because the decoder must:

  1. Attend to its own past (causal self-attention).
  2. Read from the encoder output (cross-attention).
  3. Transform per-token (FFN).

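Each sub-layer is wrapped in a residual connection, with LayerNorm applied to the sub-layer's input (Pre-LN style here). The three-sub-layer layout is the canonical Vaswani et al. 2017 decoder layer; the original paper applied LayerNorm after the residual add (Post-LN), and Pre-LN is the reordering used in this problem.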

The three sub-layers

x_q → LN → causal_self_attn(LN(x_q))     → +x_q  (sub-layer 1)
    → LN → cross_attn(LN(x_q), LN(x_kv)) → +x_q  (sub-layer 2)
    → LN → FFN(LN(x_q))                  → +x_q  (sub-layer 3)
    → out

Sub-layer 2 is the key innovation: queries come from the decoder (x_q), keys/values come from the encoder (x_kv). This is how the decoder accesses the encoder's representation.
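
The diagram above can be read as three applications of the same pattern, x + sublayer(LN(x)). The sketch below spells out that data flow only. The layer_norm helper is a parameter-free stand-in (no learned scale or bias), and the three sub-layers passed in are hypothetical placeholders that return zeros, so nothing here is a real attention or FFN implementation.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the feature axis (assumption: no scale/bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def pre_ln_block(x_q, x_kv, self_attn, cross_attn, ffn):
    # Each step normalizes the input, applies the sub-layer, then adds the
    # un-normalized input back (the residual).
    x_q = x_q + self_attn(layer_norm(x_q))                      # sub-layer 1
    x_q = x_q + cross_attn(layer_norm(x_q), layer_norm(x_kv))   # sub-layer 2
    x_q = x_q + ffn(layer_norm(x_q))                            # sub-layer 3
    return x_q


# Placeholder sub-layers (hypothetical): they only preserve the query shape.
x_q = jnp.ones((4, 8))
x_kv = jnp.ones((3, 8))
out = pre_ln_block(
    x_q,
    x_kv,
    self_attn=lambda q: jnp.zeros_like(q),
    cross_attn=lambda q, kv: jnp.zeros_like(q),
    ffn=lambda q: jnp.zeros_like(q),
)
print(out.shape)  # (4, 8)
```

Because the residual adds the un-normalized x_q, a sub-layer that outputs zeros leaves the input unchanged, which is part of why the residual path stabilizes training.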

Causal self-attention

Self-attention with a lower-triangular mask so position t only attends to positions ≤ t. Same as the encoder block's MHA, just with mask=jnp.tril(jnp.ones((T_q, T_q))) passed in.

Without the causal mask, training would leak future tokens into each prediction; see pos 23 (flax-mha-causal).
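
To see what the lower-triangular mask does to the attention weights, here is a small sketch in plain jax.numpy. It uses all-zero raw scores (an assumption made purely so that the mask alone determines the result) and applies the mask by hand rather than through the Flax module.

```python
import jax
import jax.numpy as jnp

T_q = 4
mask = jnp.tril(jnp.ones((T_q, T_q)))  # 1 where key position <= query position

# Uniform raw scores, so the mask alone decides the attention weights.
scores = jnp.zeros((T_q, T_q))
masked = jnp.where(mask > 0, scores, -jnp.inf)
weights = jax.nn.softmax(masked, axis=-1)

print(weights)
# Row t spreads its weight evenly over positions 0..t and puts 0 on the rest:
# row 0 is [1, 0, 0, 0]; row 3 is [0.25, 0.25, 0.25, 0.25].
```

Everything above the diagonal is exactly zero, so position t receives no information from positions after t.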

Cross-attention

Cross-attention is the two-input MHA call:

out = nn.MultiHeadDotProductAttention(...)(inputs_q, inputs_kv)

T_q and T_kv can differ: the decoder might be 4 tokens long while the encoder fed 16. The score matrix is (T_q, T_kv) per head, and the output is (T_q, D). Same query length in, same query length out. No mask here: the decoder is free to read any encoder position.

Worked walk-through

With T_q=4, T_kv=3, D=8, H=2, d_ff=16:

  1. Causal self-attn on x_q: (4, 8) → (4, 8). Residual.
  2. Cross-attn: Q from updated x_q, K/V from x_kv. (4, 8) out. Residual.
  3. FFN per position: Dense(16) → relu → Dense(8). Residual.
  4. Output (4, 8), same shape as x_q.

Three LayerNorms, three residuals, three sub-layers: count them.
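
Step 3 says the FFN acts "per position". The sketch below checks that claim with hand-rolled weights in plain jax.numpy: running one row through the FFN on its own gives the same result as that row of the full output. The kernels W1 and W2 are random stand-ins and biases are omitted, so this illustrates the shape logic only and is not the Flax Dense layers themselves.

```python
import jax
import jax.numpy as jnp

T_q, D, d_ff = 4, 8, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W1 = jax.random.normal(k1, (D, d_ff)) * 0.1  # stand-in for the Dense(16) kernel
W2 = jax.random.normal(k2, (d_ff, D)) * 0.1  # stand-in for the Dense(8) kernel
x = jax.random.normal(k3, (T_q, D))


def ffn(h):
    # Dense(d_ff) -> relu -> Dense(D), biases omitted for brevity.
    return jax.nn.relu(h @ W1) @ W2


full = ffn(x)     # all 4 positions at once
row2 = ffn(x[2])  # position 2 on its own

print(full.shape)  # (4, 8)
print(jnp.allclose(full[2], row2, atol=1e-5))
```

No information moves between positions in this sub-layer; mixing across positions happens only in the two attention sub-layers.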

Common pitfalls

  • Forgetting the causal mask on the FIRST attention sub-layer. That's the difference between training a decoder and training a bi-directional encoder.
  • Putting a causal mask on the cross-attention: there is no causality across encoder/decoder, because the encoder sequence is fully observed.
  • Mixing up which input is Q vs KV: queries always come from the decoder side (the side being generated), and K/V come from the encoder.
  • Skipping a residual: easy to forget the third one. All three sub-layers MUST have residuals for stable training.

Problem

Implement decoder_block_forward(seed, x_q, x_kv, num_heads, d_ff) using a Pre-LN decoder block:

  1. D = x_q.shape[-1]; T = x_q.shape[0]; mask = jnp.tril(jnp.ones((T, T))).
  2. Causal self-attn on LN(x_q), residual.
  3. Cross-attn from LN(x_q) queries to LN(x_kv) keys/values, residual.
  4. FFN on LN(x_q): Dense(d_ff) → relu → Dense(D), residual.
  5. Return x_q.reshape(-1).

Build a small nn.Module (DecoderBlock) whose __call__ method is decorated with @nn.compact.

Inputs:

  • seed: int.
  • x_q: 2-D (T_q, D).
  • x_kv: 2-D (T_kv, D).
  • num_heads: int H. D must be divisible by H.
  • d_ff: int FFN hidden dim.

Output: 1-D, the flattened (T_q, D) output.

Hints

flax transformer decoder cross-attention
