hard primitives

Mini T5: Encoder-Decoder with RMSNorm and Tied Embeddings

Why this matters

T5 (Raffel et al., 2019) is the encoder-decoder Transformer that popularised “everything is text-to-text.” Translation, summarisation, and classification are all formatted as input text → output text. It has two distinct stacks: an encoder that reads the input bidirectionally, and a decoder that generates the output autoregressively while cross-attending to the encoder's output.

Beyond the encoder/decoder split, T5 made several design choices that later became common in LLMs:

  • RMSNorm instead of LayerNorm: simpler (no mean subtraction, no bias) and empirically about as good.
  • No learned absolute positional embeddings: T5 uses a relative position bias inside attention. (For brevity we omit relative positions here; see pos 36, flax-t5-relative-pos, for that piece.)
  • Weight tying across the encoder embedding, decoder embedding, and output projection: one (V, D) matrix used three ways.

Architecture

enc_ids → embed(V, D) → x_enc → enc_block × N_enc → RMSNorm
                                ↓
                                │ (encoder context)
                                ↓
dec_ids → embed(V, D) → x_dec → dec_block × N_dec → RMSNorm
                                ↓
                                × embed.T  (tied head)
                                ↓
                              logits (T_dec, V)

Encoder block: bidirectional self-attn + FFN, RMSNorm, residuals. Decoder block: causal self-attn + cross-attn + FFN, RMSNorm, residuals. Both are Pre-LN style (x + sublayer(RMSNorm(x))).

RMSNorm refresher

From pos 19 (flax-implement-rmsnorm):

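gamma = self.param("gamma", nn.initializers.ones, (D,))   # learned scale, D = d_model
ms = jnp.mean(x ** 2, axis=-1, keepdims=True)              # mean square over features
return gamma * x / jnp.sqrt(ms + eps)                      # eps: small constant, e.g. 1e-6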

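No mean subtraction, no β bias. Cheaper than LayerNorm, and in practice it trains to comparable loss. T5 used this form (its paper calls it a simplified layer normalization: rescale only, no additive bias) from the start.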

Weight tying

ONE nn.Embed(V, D) instance serves all three roles:

  1. x_enc = embed(enc_ids): encoder input embedding.
  2. x_dec = embed(dec_ids): decoder input embedding (same matrix).
  3. logits = x_dec @ embed.embedding.T: output head (transposed matrix), applied to the final decoder state.

Compared with three separate tables, this saves 2·V·D parameters, and it can help quality, since gradients from all three usages update the same matrix during training.

Worked walk-through

With V=16, D=8, H=2, d_ff=16, N_enc=2, N_dec=2, T_enc=4, T_dec=3:

  1. x_enc = embed([0,3,7,1]) → (4, 8).
  2. Two encoder blocks (bidirectional self-attn + FFN), then the final RMSNorm → x_enc (4, 8).
  3. x_dec = embed([5,2,8]) → (3, 8).
  4. Two decoder blocks. Each does:
    • Causal self-attn on x_dec.
    • Cross-attn: Q from x_dec, K/V from x_enc.
    • FFN.
  5. Final RMSNorm. logits = x_dec @ embed.T → (3, 16).
  6. Flatten → (48,).
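Step 4's cross-attention is where the two sequence lengths meet. The plain jax.numpy sketch below uses random stand-ins for the projection matrices (w_q, w_k, w_v, w_o are hypothetical, not trained weights) to show the shapes with the walk-through's D=8, H=2, T_enc=4, T_dec=3.

```python
import jax
import jax.numpy as jnp

D, H = 8, 2
hd = D // H  # per-head dim = 4


def split_heads(x):
    # (T, D) -> (H, T, hd)
    return x.reshape(x.shape[0], H, hd).transpose(1, 0, 2)


def cross_attention(x_dec, x_enc, w_q, w_k, w_v, w_o):
    q = split_heads(x_dec @ w_q)   # (H, T_dec, hd): queries from the decoder
    k = split_heads(x_enc @ w_k)   # (H, T_enc, hd): keys from the encoder
    v = split_heads(x_enc @ w_v)   # (H, T_enc, hd): values from the encoder
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(hd)   # (H, T_dec, T_enc)
    weights = jax.nn.softmax(scores, axis=-1)           # no mask: every decoder step sees all encoder tokens
    out = (weights @ v).transpose(1, 0, 2).reshape(x_dec.shape[0], D)  # (T_dec, D)
    return out @ w_o, weights


keys = jax.random.split(jax.random.PRNGKey(0), 6)
x_enc = jax.random.normal(keys[0], (4, D))   # stands in for the final encoder output
x_dec = jax.random.normal(keys[1], (3, D))   # stands in for the decoder state
w_q, w_k, w_v, w_o = (jax.random.normal(k, (D, D)) / jnp.sqrt(D) for k in keys[2:])
out, weights = cross_attention(x_dec, x_enc, w_q, w_k, w_v, w_o)
```

The attention weights have shape (H, T_dec, T_enc) = (2, 3, 4), while the output keeps the decoder's length, (3, 8), so the residual add in the decoder block lines up.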

Common pitfalls

  • Three separate Embed instances: this defeats weight tying. Use ONE nn.Embed and read its .embedding attribute for the output head.
  • LayerNorm instead of RMSNorm: works numerically but isn’t T5.
  • Forgetting the causal mask on decoder self-attn: the same trap as in GPT; without it, each position can attend to future tokens.
  • Using cross-attn in the encoder: there is no encoder-to-encoder cross-attention. The encoder is pure self-attention.
  • Adding learned position embeddings: T5 uses a relative position bias, not learned absolute embeddings. We skip both for simplicity (see pos 36 for the relative bias).
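The causal-mask pitfall can be demonstrated directly: perturb only the last token and see whether earlier outputs change. This sketch uses a hypothetical single-head self_attention with identity projections to isolate the mask's effect; it is not the decoder block itself.

```python
import jax
import jax.numpy as jnp


def self_attention(x, causal):
    # Single-head attention with identity projections, to isolate the mask's effect.
    t, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)                          # (T, T)
    if causal:
        mask = jnp.tril(jnp.ones((t, t), dtype=bool))       # position i may attend to j <= i
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ x


x = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
x_changed = x.at[2].add(5.0)   # perturb only the last (future) token

masked = self_attention(x, causal=True)
masked_changed = self_attention(x_changed, causal=True)
leaky = self_attention(x, causal=False)
leaky_changed = self_attention(x_changed, causal=False)
```

With the mask, rows 0 and 1 are identical before and after the perturbation; without it, row 0 already changes, meaning position 0 has "seen" a future token.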

Problem

Implement mini_t5_forward(seed, x_enc_ids, x_dec_ids, vocab_size, d_model, num_heads, d_ff, num_enc_layers, num_dec_layers):

  1. Cast configs to int. Cast both id arrays to jnp.int32.
  2. Create ONE nn.Embed(V, D) and use it to embed both the encoder and decoder ids.
  3. Encoder: N_enc bidirectional blocks (RMSNorm + MHA + FFN, residuals).
  4. RMSNorm on encoder output.
  5. Decoder: N_dec decoder blocks (RMSNorm + causal self-attn + cross-attn + FFN, residuals).
  6. RMSNorm on decoder output.
  7. logits = x_dec @ embed.embedding.T. Flatten.

Inputs:

  • seed: int.
  • x_enc_ids, x_dec_ids: 1-D float arrays holding token ids (cast them to int32 first; nn.Embed rejects non-integer inputs).
  • All other args: ints.

Output: 1-D, length T_dec * V.

Hints

flax transformer t5 rmsnorm encoder-decoder
