Mini T5: Encoder-Decoder with RMSNorm and Tied Embeddings
Why this matters
T5 (Raffel et al., 2019) is the encoder-decoder Transformer that popularized "everything is text-to-text": translation, summarisation, and classification are all formatted as input text → output text. It has two distinct stacks: an encoder reads the input bidirectionally, and a decoder generates the output autoregressively while cross-attending to the encoder.
Beyond the encoder/decoder split, T5 made several design choices that became common in later LLMs:
- RMSNorm instead of LayerNorm: strictly simpler (no mean subtraction, no bias) and about as effective in practice.
- No learned positional embeddings: T5 uses a relative position bias inside attention. (For brevity we omit relative position bias here; see pos 36, flax-t5-relative-pos, for that piece.)
- Weight tying across the encoder embedding, the decoder embedding, and the output projection: one (V, D) matrix used three ways.
Architecture
enc_ids → embed(V, D) → x_enc → enc_block × N_enc → RMSNorm
                                                      │
                                                      ↓ (encoder context)
                                                      │
dec_ids → embed(V, D) → x_dec → dec_block × N_dec → RMSNorm
                                                      │
                                                      × embed.T (tied head)
                                                      ↓
                                                logits (T_dec, V)
Encoder block: bidirectional self-attn + FFN, RMSNorm, residuals.
Decoder block: causal self-attn + cross-attn + FFN, RMSNorm,
residuals. Both are Pre-LN style (x + sublayer(RMSNorm(x))).
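Spelled out, the Pre-LN pattern gives the following per-block updates. Here x is the encoder block input, y the decoder block input, and e the final-normed encoder output; CrossAttn takes queries from its first argument and keys/values from e. Each numbered RMSNorm has its own γ; the subscripts and the names h, h_1, h_2 are our notation, not the exercise's.

```latex
\begin{aligned}
\mathrm{RMSNorm}(x) &= \gamma \odot \frac{x}{\sqrt{\tfrac{1}{D}\sum_{i=1}^{D} x_i^2 + \epsilon}} \\[6pt]
\text{Encoder block:}\quad h &= x + \mathrm{SelfAttn}\big(\mathrm{RMSNorm}_1(x)\big) \\
x_{\text{out}} &= h + \mathrm{FFN}\big(\mathrm{RMSNorm}_2(h)\big) \\[6pt]
\text{Decoder block:}\quad h_1 &= y + \mathrm{CausalSelfAttn}\big(\mathrm{RMSNorm}_1(y)\big) \\
h_2 &= h_1 + \mathrm{CrossAttn}\big(\mathrm{RMSNorm}_2(h_1),\, e\big) \\
y_{\text{out}} &= h_2 + \mathrm{FFN}\big(\mathrm{RMSNorm}_3(h_2)\big)
\end{aligned}
```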
RMSNorm refresher
From pos 19 (flax-implement-rmsnorm):
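# inside an nn.Module's __call__; x has shape (..., D), eps is a small constant (e.g. 1e-6)
gamma = self.param("gamma", nn.initializers.ones, (D,))  # learned per-feature scale
ms = jnp.mean(x ** 2, axis=-1, keepdims=True)  # mean of squares over the feature axis
return gamma * x / jnp.sqrt(ms + eps)  # rescale only: no mean subtraction, no bias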
No mean subtraction and no β bias. It is cheaper than LayerNorm and, in practice, trains to comparable quality. T5 used it from the start.
Weight tying
ONE nn.Embed(V, D) instance serves all three roles:
- x_enc = embed(enc_ids) → encoder input embedding.
- x_dec = embed(dec_ids) → decoder input embedding (same matrix).
- logits = x_dec @ embed.embedding.T → output head (transposed).
Compared with three separate (V, D) matrices, this saves 2·V·D parameters, and it can also improve quality, because all three usages update the same matrix during training.
Worked walk-through
With V=16, D=8, H=2, d_ff=16, N_enc=2, N_dec=2, T_enc=4, T_dec=3:
- x_enc = embed([0,3,7,1]) → (4, 8).
- Two encoder blocks (bidirectional self-attn + FFN), then a final RMSNorm → x_enc (4, 8).
- x_dec = embed([5,2,8]) → (3, 8).
- Two decoder blocks. Each does:
  - Causal self-attn on x_dec.
  - Cross-attn: Q from x_dec, K/V from x_enc.
  - FFN.
- Final RMSNorm.
- logits = x_dec @ embed.T → (3, 16).
- Flatten → (48,).
Common pitfalls
- Three separate Embed instances: defeats weight tying. Use ONE nn.Embed and read its .embedding attribute for the matrix.
- LayerNorm instead of RMSNorm: works numerically but isn't T5.
- Forgetting the causal mask on decoder self-attn: same trap as GPT.
- Using cross-attn in the encoder: there is no encoder-encoder cross-attention. The encoder is pure self-attention.
- Adding learned position embeddings: T5 uses relative position bias, not learned absolute embeddings. We skip both for simplicity (see pos 36 for relative bias).
Problem
Implement mini_t5_forward(seed, x_enc_ids, x_dec_ids, vocab_size, d_model, num_heads, d_ff, num_enc_layers, num_dec_layers):
- Cast the config arguments to int. Cast both id arrays to jnp.int32.
- Create ONE nn.Embed(V, D) and embed both the encoder and decoder ids through it.
- Encoder: N_enc bidirectional blocks (RMSNorm + MHA + FFN, with residuals), then RMSNorm on the encoder output.
- Decoder: N_dec decoder blocks (RMSNorm + causal self-attn + cross-attn + FFN, with residuals), then RMSNorm on the decoder output.
- logits = x_dec @ embed.embedding.T. Flatten.
Inputs:
- seed: int.
- x_enc_ids, x_dec_ids: 1-D float arrays.
- All other args: ints.
Output: 1-D array of length T_dec * V.