Mini LM Capstone — Putting It All Together
Why this matters
This is the final problem of the Flax track. Every component you’ve built across the last 99 positions converges here into one artifact: a complete, modern, decoder-only language model in the style of LLaMA / Mistral / Gemma — small enough to verify arithmetically, real enough to produce a logits tensor that could be trained on next-token prediction.
A modern LM is just a stack of well-understood pieces. You’ve seen each in isolation; now you’ll wire them together.
The architecture
token_ids (T,)
→ Embed(V, D) → tok (T, D)
→ + sinusoidal_pos(T, D) → x (T, D)
→ DecoderBlock × N → x (T, D)
→ RMSNorm() → x (T, D)
→ x @ embed.embedding.T → logits (T, V) # tied head
Each DecoderBlock:
h = RMSNorm(x)
h = CausalMHA(h) # Pre-LN: norm INSIDE the residual branch
x = x + h
h = RMSNorm(x)
h = SwiGLU_FFN(h)
x = x + h
Component-by-component recap (with track positions)
1. Token embedding (pos 39, flax-token-embed).
nn.Embed(V, D) is a (V, D) lookup table. embed(ids) selects
rows. Required dtype is integer — cast token_ids.astype(jnp.int32)
before calling.
2. Sinusoidal position encoding (pos 40, flax-sinusoidal-pos).
Vaswani et al.’s deterministic position formula:
PE[pos, 2i] = sin(pos / 10000^(2i/D))
PE[pos, 2i+1] = cos(pos / 10000^(2i/D))
No learned parameters; computed once per call. Modern LLMs use
RoPE or ALiBi instead, but sinusoidal is the cleanest pedagogical
choice and avoids needing a learned (T, D) parameter.
3. RMSNorm (pos 19, flax-implement-rmsnorm).
Root-mean-square normalization with one trainable γ vector. No
mean subtraction, no β bias. Used by LLaMA, T5, Gemma. Cheaper
than LayerNorm, with comparable accuracy in practice.
4. Causal multi-head attention (pos 24, flax-mha-causal).
nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
with a lower-triangular mask=jnp.tril(jnp.ones((T, T))) so each
position attends only to positions ≤ itself. The whole point of
a decoder-only LM.
5. SwiGLU FFN (pos 48, flax-swiglu-ffn).
Three-Dense gated FFN: silu(gate(x)) * up(x) then down(...).
Used by LLaMA, PaLM, and Mistral, among others; tends to outperform
the classic Dense → ReLU → Dense at the same parameter count.
6. Pre-LN residual (pos 35, flax-prelayernorm-block).
Norm goes INSIDE the residual branch (not after x + h as in the
original Transformer’s “post-LN” formulation). More stable to train
and less dependent on warmup; standard since GPT-2.
7. Tied output head (pos 41, flax-tied-io-embed, also pos 42’s
mini-GPT). Reuse embed.embedding as the output projection
instead of a fresh Dense(V). Saves V·D parameters and tends
to improve generalization.
8. Stacked blocks. for _ in range(N): x = DecoderBlock(...)(x).
Each call creates fresh params (Flax auto-naming gives you
DecoderBlock_0, DecoderBlock_1, ...).
Compare to the older mini-GPT (pos 42), which used learned position embeddings, LayerNorm, and a ReLU FFN. This capstone uses the modern set: sinusoidal pos, RMSNorm, SwiGLU. The overall shape is the same — that’s the point. Architecture is modular, swappable.
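The recap above can be sketched end to end in plain JAX, without Flax modules. This is a minimal illustration under stated assumptions, not the reference solution: the parameter dictionary layout, the names `init_params`, `causal_mha`, `swiglu_ffn`, and `forward`, the 0.02 init scale, `eps=1e-6`, and the bias-free projections are all choices made for this sketch. Flax's `nn.MultiHeadDotProductAttention` uses biases by default, so its numbers will differ.

```python
import jax
import jax.numpy as jnp

def sinusoidal_pos_encoding(T, D):
    # PE[pos, 2i] = sin(angle), PE[pos, 2i+1] = cos(angle); assumes D is even.
    pos = jnp.arange(T, dtype=jnp.float32)[:, None]           # (T, 1)
    two_i = jnp.arange(0, D, 2, dtype=jnp.float32)[None, :]   # (1, D/2)
    angles = pos / jnp.power(10000.0, two_i / D)              # (T, D/2)
    # Stack on the last axis, then reshape, to interleave sin and cos.
    return jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1).reshape(T, D)

def rms_norm(x, gamma, eps=1e-6):
    # No mean subtraction, no beta; eps is an assumed value.
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def causal_mha(x, p, H):
    T, D = x.shape
    split = lambda a: a.reshape(T, H, D // H)                 # (T, H, head_dim)
    q, k, v = split(x @ p["wq"]), split(x @ p["wk"]), split(x @ p["wv"])
    scores = jnp.einsum("qhd,khd->hqk", q, k) / (D // H) ** 0.5
    mask = jnp.tril(jnp.ones((T, T))).astype(bool)            # position t sees <= t
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("hqk,khd->qhd", weights, v).reshape(T, D)
    return out @ p["wo"]

def swiglu_ffn(x, p):
    return (jax.nn.silu(x @ p["gate"]) * (x @ p["up"])) @ p["down"]

def init_params(key, V, D, d_ff, N):
    # Hypothetical layout: one dict per block plus the shared embedding table.
    shapes = {"wq": (D, D), "wk": (D, D), "wv": (D, D), "wo": (D, D),
              "gate": (D, d_ff), "up": (D, d_ff), "down": (d_ff, D)}
    k_embed, *k_blocks = jax.random.split(key, N + 1)
    blocks = []
    for kb in k_blocks:
        ks = jax.random.split(kb, len(shapes))
        block = {name: 0.02 * jax.random.normal(k, shape)
                 for k, (name, shape) in zip(ks, shapes.items())}
        block["g1"], block["g2"] = jnp.ones(D), jnp.ones(D)
        blocks.append(block)
    return {"embed": 0.02 * jax.random.normal(k_embed, (V, D)),
            "final_g": jnp.ones(D), "blocks": blocks}

def forward(params, token_ids, H):
    ids = jnp.asarray(token_ids).astype(jnp.int32)            # float ids -> int
    E = params["embed"]                                       # (V, D)
    x = E[ids] + sinusoidal_pos_encoding(ids.shape[0], E.shape[1])
    for b in params["blocks"]:                                # pre-LN residual blocks
        x = x + causal_mha(rms_norm(x, b["g1"]), b, H)
        x = x + swiglu_ffn(rms_norm(x, b["g2"]), b)
    x = rms_norm(x, params["final_g"])                        # final norm before head
    return (x @ E.T).reshape(-1)                              # tied head, flattened

params = init_params(jax.random.PRNGKey(0), V=16, D=8, d_ff=16, N=2)
logits = forward(params, jnp.array([0.0, 3.0, 7.0, 1.0]), H=2)
print(logits.shape)  # (64,)
```

A quick sanity check on a sketch like this: change only the last token id and confirm that the logits rows for earlier positions stay the same. If they move, the causal mask is not doing its job.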
Worked walk-through
Take case 1: V=16, D=8, H=2, d_ff=16, N=2, T=4, ids=[0,3,7,1].
- tok = embed(ids) → (4, 8). Each row is the embedding vector of the corresponding token id.
- pos = sinusoidal_pos(4, 8) → (4, 8). Row 0 is [sin 0, cos 0, sin 0, cos 0, ...] = [0, 1, 0, 1, ...]; every second value is 1 because cos(0) = 1.
- x = tok + pos → (4, 8).
- Block 0: x = x + CausalMHA(RMSNorm(x)); x = x + SwiGLU(RMSNorm(x)).
- Block 1: same.
- x = RMSNorm(x) → final pre-head normalization.
- logits = x @ embed.embedding.T → (4, 16). Flatten → (64,).
Row t of logits holds unnormalized scores over the vocabulary for
the token at position t+1; a softmax turns them into the
probabilities the model assigns to each candidate next token. During
training you compare logits[t] to target = ids[t+1] with
cross-entropy.
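The shift between logits[t] and ids[t+1] is easy to get wrong, so here is a minimal sketch of the loss. The function name `next_token_loss` is hypothetical, and the all-zero logits are a stand-in chosen so the result is easy to check by hand: a uniform distribution over V=16 tokens gives a loss of log(16).

```python
import jax
import jax.numpy as jnp

def next_token_loss(logits, ids):
    # logits: (T, V); ids: (T,) int. The last row has no target, so drop it.
    log_probs = jax.nn.log_softmax(logits[:-1], axis=-1)      # (T-1, V)
    targets = ids[1:]                                         # (T-1,)
    picked = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -jnp.mean(picked)

ids = jnp.array([0, 3, 7, 1], dtype=jnp.int32)
logits = jnp.zeros((4, 16))            # stand-in: uniform prediction at every position
loss = next_token_loss(logits, ids)    # equals log(16) for uniform logits
```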
Architectural choices and their consequences
Why pre-LN, not post-LN? Pre-LN puts the normalization on the
residual branch. The skip connection (x itself) is never
normalized, so gradients flow through the network without ever
being scaled by an activation-dependent factor — much more stable
for deep networks. Post-LN (LN(x + branch)) was the original
Transformer recipe; it typically needs a careful learning-rate
warmup schedule. Pre-LN (used here, in GPT-2 and later, and in most
modern LLMs) is far less sensitive to warmup.
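The difference between the two orderings shows up even with a branch that outputs zeros. The sketch below uses a parameter-free RMS normalization and hypothetical helper names `pre_ln` and `post_ln`; it only illustrates the wiring, not a trainable block.

```python
import jax.numpy as jnp

def rms_norm(x, eps=1e-6):
    # Parameter-free RMS normalization, enough to show the wiring.
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

def pre_ln(x, branch):
    # Norm sits inside the branch; the skip path carries x unchanged.
    return x + branch(rms_norm(x))

def post_ln(x, branch):
    # Norm is applied to the sum, so the skip path is rescaled too.
    return rms_norm(x + branch(x))

x = jnp.array([[3.0, 4.0], [30.0, 40.0]])
zero_branch = lambda h: jnp.zeros_like(h)

pre_out = pre_ln(x, zero_branch)     # identical to x
post_out = post_ln(x, zero_branch)   # both rows rescaled to unit RMS
```

With pre-LN the block reduces to the identity when the branch contributes nothing; with post-LN the two rows, which differ by a factor of 10, come out the same.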
Why sinusoidal, not learned, position? For pedagogy and to keep the param count predictable. In production:
- RoPE (rotary): applied INSIDE attention, rotates Q and K by position-dependent angles. Used by LLaMA, GPT-NeoX, etc.
- ALiBi: adds a position-dependent bias to attention logits; no positional embedding at all. Used by BLOOM and MPT.
- Learned absolute position: early GPT-2 style; doesn’t extrapolate beyond training context.
Why RMSNorm, not LayerNorm? Roughly the same accuracy, cheaper compute. One reduction (mean square) instead of two (mean and variance). The β bias term turns out to be barely used by trained models, so RMSNorm just drops it.
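To make the "one reduction versus two" point concrete, here is a small sketch with γ and β left out of both functions (an assumption made to isolate the normalization itself; `eps=1e-6` is also assumed). On a zero-mean row the two agree; once the row is shifted, LayerNorm removes the shift and RMSNorm does not.

```python
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    mean = jnp.mean(x, axis=-1, keepdims=True)                # reduction 1
    var = jnp.mean((x - mean) ** 2, axis=-1, keepdims=True)   # reduction 2
    return (x - mean) / jnp.sqrt(var + eps)

def rms_norm(x, eps=1e-6):
    ms = jnp.mean(x * x, axis=-1, keepdims=True)              # single reduction
    return x / jnp.sqrt(ms + eps)

x = jnp.array([[1.0, -2.0, 3.0, -2.0]])   # this row already has mean zero

same_on_zero_mean = jnp.allclose(layer_norm(x), rms_norm(x), atol=1e-5)
ln_ignores_shift = jnp.allclose(layer_norm(x + 5.0), layer_norm(x), atol=1e-5)
rms_ignores_shift = jnp.allclose(rms_norm(x + 5.0), rms_norm(x), atol=1e-5)
```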
Why SwiGLU? Empirically: better loss at fixed parameter count. The gated multiplicative interaction lets the model learn which dimensions to amplify based on the input, something a single elementwise nonlinearity cannot do directly.
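"Fixed parameter count" needs a little arithmetic, because SwiGLU has three weight matrices where the classic FFN has two. A common convention, assumed here rather than taken from this problem, is to shrink the SwiGLU hidden width to about 2/3 of the classic one. The helper names below are hypothetical and biases are ignored.

```python
def classic_ffn_params(D, d_ff):
    # Dense(D -> d_ff) followed by Dense(d_ff -> D), no biases.
    return 2 * D * d_ff

def swiglu_ffn_params(D, d_ff):
    # gate, up, and down projections, no biases.
    return 3 * D * d_ff

D = 4096
classic_hidden = 4 * D                       # the usual 4x expansion
swiglu_hidden = (2 * classic_hidden) // 3    # 2/3 rule, rounded down

classic_total = classic_ffn_params(D, classic_hidden)
swiglu_total = swiglu_ffn_params(D, swiglu_hidden)
ratio = swiglu_total / classic_total         # just under 1.0
```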
Why tied embeddings? Cuts V·D params from the head: that’s
50000 · 4096 ≈ 205M parameters in a real LLM. Tying often
improves perplexity at small scale too.
What you’re NOT doing here (and why)
- No KV-cache: this is a forward pass for training/eval, not autoregressive decoding. The KV-cache would be relevant for a generation loop. (Covered in pos 38, flax-mha-kv-cache.)
- No dropout: kept off for clean numerical reproducibility. In a real training run, dropout (with a per-call rng) would live inside MHA and the FFN.
- No mixed precision: pos 67 covers this; the capstone runs in float32 for clarity.
- No gradient checkpointing: pos 82 covers nn.checkpoint / nn.remat. Useful for fitting deeper models in memory.
Common pitfalls
- Forgetting astype(jnp.int32): nn.Embed requires int dtype. The function receives token_ids as float; cast it.
- Forgetting the causal mask: mask=jnp.tril(jnp.ones((T,T))) inside MHA. Without it, you’ve trained a non-causal encoder that cheats by looking at future tokens.
- Turning the position encoding into a parameter: it’s a deterministic function of T and D, recomputed on each call. Don’t try to make it a learned param.
- Skipping the final norm: a final RMSNorm BEFORE the tied head matters for training stability. Modern LLMs generally include it; skipping it tends to cause loss spikes.
- Pre-LN inside the block, NOT post-LN: norm goes BEFORE the attention/FFN, not after the residual sum.
- Tied head shape: logits = x @ embed.embedding.T. The embedding matrix is (V, D), so embedding.T is (D, V). x is (T, D). Result: (T, V). Don’t accidentally use embed.embedding (without transpose).
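The first and last pitfalls are both about shapes and dtypes, and can be checked without any Flax module. In the sketch below a random (V, D) array stands in for embed.embedding; that stand-in and the flag name `transpose_needed` are illustrative only.

```python
import jax
import jax.numpy as jnp

V, D = 16, 8
embedding = jax.random.normal(jax.random.PRNGKey(0), (V, D))  # stand-in table

token_ids = jnp.array([0.0, 3.0, 7.0, 1.0])   # arrives as float
ids = token_ids.astype(jnp.int32)              # cast before any lookup
x = embedding[ids]                             # row lookup -> (T, D)
logits = x @ embedding.T                       # (T, D) @ (D, V) -> (T, V)

# Without the transpose the inner dimensions are D=8 and V=16, which do not match.
transpose_needed = False
try:
    x @ embedding
except (TypeError, ValueError):
    transpose_needed = True
```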
Problem
Implement mini_lm_capstone(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
- Cast all configs to int. Cast token_ids to jnp.int32.
- Build a MiniLM Flax Module:
  - embed = nn.Embed(V, D). tok = embed(ids).
  - pos = sinusoidal_pos_encoding(T, D) (helper function, not a Flax param).
  - x = tok + pos.
  - Loop num_layers DecoderBlocks. Each block:
    - h = RMSNorm()(x); h = CausalMHA(h); x = x + h.
    - h = RMSNorm()(x); h = SwiGLU_FFN(h); x = x + h.
  - Final x = RMSNorm()(x).
  - logits = x @ embed.embedding.T.
- Init with jax.random.PRNGKey(seed). Apply on ids. Return logits.reshape(-1).
Use the implementations of RMSNorm (pos 19) and SwiGLU_FFN
(pos 48) you’ve built before. The reference solution inlines
them.
Inputs:
- seed: float (cast to int).
- token_ids: 1-D float array (cast inside).
- vocab_size: V.
- d_model: D; must be divisible by num_heads.
- num_heads: H.
- d_ff: SwiGLU hidden dim.
- num_layers: N decoder blocks.
Output: 1-D, length T * V.
A note on finishing the track
If you’re reading this in earnest, you’ve worked through 99 problems on Flax: the Module system, every kind of normalization, every kind of attention, every kind of position encoding, train/eval state, sharding, lifts, surgery. This problem is the closing bracket — the place where the lessons compose.
A real production LLM is mostly this code, scaled up. The same
DecoderBlock, just with D=4096, H=32, d_ff=14336, N=32. The
same RMSNorm. The same SwiGLU. Tied embeddings in some model
families (Gemma, for example) and a separate output matrix in
others (Mistral 7B). What changes between a capstone toy and a
frontier model is mostly scale, data, and engineering. The
architectural primitives are the ones in this file.
Good luck. Ship it.