
Mini LM Capstone — Putting It All Together

Why this matters

This is the final problem of the Flax track. Every component you’ve built across the previous 99 problems converges here into one artifact: a complete, modern, decoder-only language model in the style of LLaMA, Mistral, and Gemma. It is small enough to verify arithmetically, yet real enough to produce a logits tensor that could be trained on next-token prediction.

A modern LM is just a stack of well-understood pieces. You’ve seen each in isolation; now you’ll wire them together.

The architecture

token_ids (T,)
  → Embed(V, D)                     → tok      (T, D)
  → + sinusoidal_pos(T, D)          → x        (T, D)
  → DecoderBlock × N                → x        (T, D)
  → RMSNorm()                       → x        (T, D)
  → x @ embed.embedding.T           → logits   (T, V)    # tied head

Each DecoderBlock:

h = RMSNorm(x)
h = CausalMHA(h)                # Pre-LN: norm INSIDE the residual branch
x = x + h
h = RMSNorm(x)
h = SwiGLU_FFN(h)
x = x + h

Component-by-component recap (with track positions)

1. Token embedding (pos 39, flax-token-embed). nn.Embed(V, D) is a (V, D) lookup table; embed(ids) selects rows. The ids must have an integer dtype, so cast with token_ids.astype(jnp.int32) before calling.

2. Sinusoidal position encoding (pos 40, flax-sinusoidal-pos). Vaswani et al.’s deterministic position formula:

PE[pos, 2i]   = sin(pos / 10000^(2i/D))
PE[pos, 2i+1] = cos(pos / 10000^(2i/D))

No learned parameters; computed once per call. Modern LLMs use RoPE or ALiBi instead, but sinusoidal is the cleanest pedagogical choice and avoids needing a learned (T, D) parameter.

3. RMSNorm (pos 19, flax-implement-rmsnorm). Root-mean-square normalization with one trainable γ vector. No mean subtraction, no β bias. Used by LLaMA, T5, PaLM, and Gemma; cheaper than LayerNorm and, empirically, about as accurate.

4. Causal multi-head attention (pos 24, flax-mha-causal). nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D) with a lower-triangular mask=jnp.tril(jnp.ones((T, T))) so each position attends only to positions ≤ itself. The whole point of a decoder-only LM.

5. SwiGLU FFN (pos 48, flax-swiglu-ffn). Three-Dense gated FFN: silu(gate(x)) * up(x), then down(...). Used by many modern LLMs (LLaMA, PaLM, Mistral); Gemma uses the closely related GeGLU. In published comparisons it outperforms the classic Dense → ReLU → Dense at a matched parameter count.

6. Pre-LN residual (pos 35, flax-prelayernorm-block). The norm goes INSIDE the residual branch, not after x + h as in the original Transformer’s “post-LN” formulation. Trains stably with little or no learning-rate warmup; standard since GPT-2.

7. Tied output head (pos 41, flax-tied-io-embed, also pos 42’s mini-GPT). Reuse embed.embedding as the output projection instead of a fresh Dense(V). Saves V·D parameters and tends to improve generalization.

8. Stacked blocks. for _ in range(N): x = DecoderBlock(...)(x). Each call creates fresh params (Flax auto-naming gives you DecoderBlock_0, DecoderBlock_1, ...).
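A minimal NumPy sketch of the sinusoidal table from item 2, named after the sinusoidal_pos_encoding(T, D) helper the problem mentions (its exact signature here is an assumption). It assumes D is even and interleaves sin and cos so that even columns hold PE[pos, 2i] and odd columns hold PE[pos, 2i+1]. Because it avoids in-place assignment, swapping np for jnp should give a JAX version.

```python
import numpy as np


def sinusoidal_pos_encoding(T: int, D: int) -> np.ndarray:
    """Sinusoidal position table of shape (T, D); assumes D is even."""
    pos = np.arange(T)[:, None]                    # (T, 1) positions 0..T-1
    two_i = np.arange(0, D, 2)[None, :]            # (1, D/2) even column indices 2i
    angles = pos / np.power(10000.0, two_i / D)    # (T, D/2): pos / 10000^(2i/D)
    # Interleave: column 2i gets sin, column 2i+1 gets cos.
    return np.stack([np.sin(angles), np.cos(angles)], axis=-1).reshape(T, D)


pe = sinusoidal_pos_encoding(4, 8)
print(pe.shape)      # (4, 8)
print(pe[0])         # [0. 1. 0. 1. 0. 1. 0. 1.]
```

Row 1 is less trivial: column 0 is sin(1) and column 2 is sin(0.1), because 10000^(2/8) = 10.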

Compare to the older mini-GPT (pos 42), which used learned position embeddings, LayerNorm, and a ReLU FFN. This capstone uses the modern set: sinusoidal pos, RMSNorm, SwiGLU. The overall shape is the same, and that’s the point: the architecture is modular, and each component can be swapped independently.

Worked walk-through

Take case 1: V=16, D=8, H=2, d_ff=16, N=2, T=4, ids=[0,3,7,1].

  1. tok = embed(ids) → (4, 8). Each row is the embedding vector of the corresponding token id.
  2. pos = sinusoidal_pos(4, 8) → (4, 8). Row 0 is [sin 0, cos 0, sin 0, cos 0, ...] = [0, 1, 0, 1, ...]: every second value is 1 because cos(0) = 1.
  3. x = tok + pos → (4, 8).
  4. Block 0: x = x + CausalMHA(RMSNorm(x)); x = x + SwiGLU(RMSNorm(x)).
  5. Block 1: same.
  6. x = RMSNorm(x) → final pre-head normalization.
  7. logits = x @ embed.embedding.T → (4, 16). Flatten → (64,).

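Row t of the logits holds unnormalized scores over the vocabulary for the token at position t+1; a softmax turns them into the model’s predicted distribution. During training you compare logits[t] with the target ids[t+1] using cross-entropy. The last row has no target inside the sequence, so it is typically dropped from the loss.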
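The shift between predictions and targets is easy to get off by one. Below is a NumPy sketch of the next-token loss, assuming logits of shape (T, V) and integer-valued ids of shape (T,); next_token_cross_entropy is a hypothetical helper for illustration, not part of the problem’s required output.

```python
import numpy as np


def next_token_cross_entropy(logits: np.ndarray, ids: np.ndarray) -> float:
    """Mean cross-entropy of logits[t] against ids[t+1]; the last row has no target."""
    preds = logits[:-1]                                   # (T-1, V)
    targets = np.asarray(ids[1:]).astype(np.int64)        # (T-1,) next tokens
    # numerically stable log-softmax over the vocabulary axis
    shifted = preds - preds.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


ids = np.array([0.0, 3.0, 7.0, 1.0])      # float ids, as the problem passes them
uniform = np.zeros((4, 16))                # all-zero logits = uniform over V = 16
print(next_token_cross_entropy(uniform, ids))   # log(16) ≈ 2.7726
```

An untrained model with near-uniform logits should start near log(V); a loss far above that at step 0 usually points to a scaling bug.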

Architectural choices and their consequences

Why pre-LN, not post-LN? Pre-LN puts the normalization on the residual branch. The skip connection (x itself) is never normalized, so every block keeps an identity path and gradients reach early layers without passing through a normalization at each depth. This is much more stable for deep networks. Post-LN (LN(x + branch)) was the original Transformer recipe; it typically needs a careful learning-rate warmup schedule. Pre-LN (used here, in GPT-2 and later, and in most modern LLMs) is far less sensitive to warmup.
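A toy NumPy comparison of the two residual layouts, using a γ-free RMSNorm as the norm in both (an assumption for brevity). With a branch that outputs zeros, Pre-LN returns x untouched because the skip path is never normalized, while Post-LN rescales x.

```python
import numpy as np


def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # gamma-free RMSNorm over the last axis (gamma omitted for brevity)
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)


def pre_ln_residual(x, branch):
    # Pre-LN: only the branch input is normalized; the skip path carries x unchanged.
    return x + branch(rms_norm(x))


def post_ln_residual(x, branch):
    # Post-LN (original Transformer): the sum is normalized, so the skip path is too.
    return rms_norm(x + branch(x))


x = np.array([[3.0, 4.0]])
zero_branch = lambda h: np.zeros_like(h)     # a sublayer that contributes nothing
print(pre_ln_residual(x, zero_branch))       # [[3. 4.]]
print(post_ln_residual(x, zero_branch))      # ≈ [[0.8485 1.1314]]
```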

Why sinusoidal, not learned, position? For pedagogy and to keep the param count predictable. In production:

  • RoPE (rotary): applied INSIDE attention, rotates Q and K by position-dependent angles. Used by LLaMA, GPT-NeoX, etc.
  • ALiBi: adds a position-dependent bias to attention logits; no positional embedding at all. Used by BLOOM and MPT.
  • Learned absolute position: GPT-2 / BERT style; a fixed-size table that doesn’t extrapolate beyond the training context.
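For contrast with the additive sinusoidal table, here is a NumPy sketch of the RoPE idea from the list above. It is a simplified illustration, not LLaMA’s or GPT-NeoX’s code: it rotates adjacent feature pairs (x[2i], x[2i+1]), whereas some implementations pair the first and second halves of the head dimension instead, and the base of 10000 is assumed. The check at the end shows the property that motivates RoPE: the query-key dot product depends only on the offset between positions.

```python
import numpy as np


def rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotate each adjacent pair (x[:, 2i], x[:, 2i+1]) by angle pos * base**(-2i/d)."""
    d = x.shape[-1]                                        # assumes d is even
    inv_freq = base ** (-np.arange(0, d, 2) / d)           # (d/2,)
    angles = positions[:, None] * inv_freq[None, :]        # (T, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    # 2-D rotation of every pair, then interleave the results back into (T, d)
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    return np.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)


rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 8)), rng.normal(size=(1, 8))


def score(m: int, n: int) -> float:
    # dot product of q placed at position m with k placed at position n
    return float(rope(q, np.array([m]))[0] @ rope(k, np.array([n]))[0])


print(np.isclose(score(2, 5), score(10, 13)))   # True: same offset n - m = 3
```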

Why RMSNorm, not LayerNorm? Roughly the same accuracy at lower compute cost: one reduction (mean square) instead of two (mean and variance). Empirically, the mean subtraction contributes little to LayerNorm’s benefit, so RMSNorm drops it, and the β bias along with it.
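A NumPy sketch of both norms (eps = 1e-6 is an assumed value, not necessarily the reference solution’s). It makes the relationship concrete: on a row whose mean is already zero, with β = 0, the two coincide, and RMSNorm differs only in skipping the re-centering once the mean moves away from zero.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)                  # reduction 1: mean
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)   # reduction 2: variance
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def rms_norm(x, gamma, eps=1e-6):
    ms = (x ** 2).mean(axis=-1, keepdims=True)           # single reduction: mean square
    return gamma * x / np.sqrt(ms + eps)                 # no re-centering, no beta


g, b = np.ones(4), np.zeros(4)
centered = np.array([[1.0, -1.0, 2.0, -2.0]])            # row mean is exactly 0
shifted = centered + 5.0                                  # same row, mean moved to 5
print(np.allclose(layer_norm(centered, g, b), rms_norm(centered, g)))  # True
print(np.allclose(layer_norm(shifted, g, b), rms_norm(shifted, g)))    # False
```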

Why SwiGLU? Empirically: better loss at a fixed parameter count. The gated multiplicative interaction scales each hidden unit by an input-dependent gate, so the model can learn which dimensions to amplify or suppress for each input; a plain Dense → ReLU → Dense has no such multiplicative term.
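A NumPy sketch of the gated FFN next to the classic one, with bias-free weight matrices (an assumption; the reference solution may use biases). SwiGLU has three D × d_ff matrices instead of two, so matching parameter counts means shrinking its hidden size to 2/3 of the classic one. With case 1’s D=8 and d_ff=16, SwiGLU has the same 384 weights as a ReLU FFN with hidden size 24.

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))                   # silu(z) = z * sigmoid(z)


def swiglu_ffn(x, w_gate, w_up, w_down):
    # down(silu(gate(x)) * up(x)); shapes (D, F), (D, F), (F, D); biases omitted
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


def relu_ffn(x, w_in, w_out):
    # classic Dense -> ReLU -> Dense; shapes (D, F), (F, D); biases omitted
    return np.maximum(x @ w_in, 0.0) @ w_out


D, F_swiglu, F_relu = 8, 16, 24                     # 24 = 16 * 3 / 2
rng = np.random.default_rng(0)
x = rng.normal(size=(4, D))
w_gate, w_up = rng.normal(size=(D, F_swiglu)), rng.normal(size=(D, F_swiglu))
w_down = rng.normal(size=(F_swiglu, D))
w_in, w_out = rng.normal(size=(D, F_relu)), rng.normal(size=(F_relu, D))

print(swiglu_ffn(x, w_gate, w_up, w_down).shape, relu_ffn(x, w_in, w_out).shape)  # (4, 8) (4, 8)
print(w_gate.size + w_up.size + w_down.size, w_in.size + w_out.size)               # 384 384
```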

Why tied embeddings? Tying removes the separate V·D head matrix: with V=50000 and D=4096, that’s 50000 · 4096 ≈ 205M parameters. Tying has also been reported to improve perplexity, even at small scale.
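A NumPy sketch of the tied head’s shapes at the case 1 sizes. A random matrix stands in for embed.embedding, and the final hidden state is faked with embedding rows purely to check shapes; both are illustrative assumptions.

```python
import numpy as np

V, D, T = 16, 8, 4                                   # case 1 sizes
rng = np.random.default_rng(0)
embedding = rng.normal(size=(V, D))                  # stand-in for embed.embedding, shape (V, D)
ids = np.array([0, 3, 7, 1])

x = embedding[ids]                                   # stand-in for the final hidden state, (T, D)
logits = x @ embedding.T                             # tied head: (T, D) @ (D, V) -> (T, V)
print(logits.shape, logits.reshape(-1).shape)        # (4, 16) (64,)

# logits[t, v] is the dot product of hidden state t with the embedding row of token v.
print(np.allclose(logits[1], embedding @ x[1]))      # True

# Weights the tied head avoids compared with a bias-free untied Dense(V) head: V * D.
print(V * D, 50000 * 4096)                           # 128 204800000
```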

What you’re NOT doing here (and why)

  • No KV-cache: this is a forward pass for training/eval, not autoregressive decoding. The KV-cache would be relevant for a generation loop. (Covered in pos 38, flax-mha-kv-cache.)
  • No dropout: kept off for clean numerical reproducibility. In a real training run, dropout (with a per-call rng) would live inside MHA and the FFN.
  • No mixed precision: pos 67 covers this; the capstone runs in float32 for clarity.
  • No gradient checkpointing: pos 82 covers nn.checkpoint / nn.remat. Useful for fitting deeper models in memory.

Common pitfalls

  • Forgetting astype(jnp.int32): nn.Embed requires int dtype. The function receives token_ids as float; cast it.
  • Forgetting the causal mask: mask=jnp.tril(jnp.ones((T,T))) inside MHA. Without it, you’ve trained a non-causal encoder that cheats by looking at future tokens.
  • Computing the position encoding once for a fixed length, or storing it as a param: it’s a function of T, recomputed on each call for the current sequence length. Don’t try to make it a learned param.
  • Missing final norm: in a Pre-LN stack the residual stream itself is never normalized, so a final RMSNorm BEFORE the tied head is what keeps the logits on a controlled scale. Pre-LN LLMs include it, and leaving it out can destabilize training.
  • Pre-LN inside the block, NOT post-LN: norm goes BEFORE the attention/FFN, not after the residual sum.
  • Tied head shape: logits = x @ embed.embedding.T. The embedding matrix is (V, D), so embedding.T is (D, V). x is (T, D). Result: (T, V). Don’t accidentally use embed.embedding (without transpose).
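The causal-mask pitfall can be checked directly: perturb only the last position and see whether earlier outputs move. The sketch below is a single-head NumPy attention with random, 1/√D-scaled weights (all assumptions for illustration). With the jnp.tril-style mask, rows 0 to T-2 come out identical; without it, the change at the last position leaks into them.

```python
import numpy as np


def single_head_attention(x, wq, wk, wv, causal: bool):
    T, d = x.shape[0], wq.shape[1]
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(d)                          # (T, T)
    if causal:
        mask = np.tril(np.ones((T, T), dtype=bool))        # same pattern as jnp.tril(jnp.ones((T, T)))
        scores = np.where(mask, scores, -1e9)              # position t sees only positions <= t
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)         # row-wise softmax
    return weights @ v


T, D = 4, 8
rng = np.random.default_rng(0)
wq, wk, wv = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
x = rng.normal(size=(T, D))
x_future = x.copy()
x_future[-1] += 1.0                                        # change only the last position

for causal in (True, False):
    a = single_head_attention(x, wq, wk, wv, causal)
    b = single_head_attention(x_future, wq, wk, wv, causal)
    # Did rows 0..T-2 stay the same when only the future changed?
    print(f"causal={causal}: earlier rows unchanged = {np.allclose(a[:-1], b[:-1])}")
```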

Problem

Implement mini_lm_capstone(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):

  1. Cast all configs to int. Cast token_ids to jnp.int32.
  2. Build a MiniLM Flax Module:
    • embed = nn.Embed(V, D). tok = embed(ids).
    • pos = sinusoidal_pos_encoding(T, D) (helper function, not a Flax param).
    • x = tok + pos.
    • Loop num_layers DecoderBlocks. Each block:
      • h = RMSNorm()(x); h = CausalMHA(h); x = x + h.
      • h = RMSNorm()(x); h = SwiGLU_FFN(h); x = x + h.
    • Final x = RMSNorm()(x).
    • logits = x @ embed.embedding.T.
  3. Init with jax.random.PRNGKey(seed). Apply on ids. Return logits.reshape(-1).

Use the implementations of RMSNorm (pos 19) and SwiGLU_FFN (pos 48) you’ve built before. The reference solution inlines them.

Inputs:

  • seed: float (cast to int).
  • token_ids: 1-D float array (cast inside).
  • vocab_size: V.
  • d_model: D — must be divisible by num_heads.
  • num_heads: H.
  • d_ff: SwiGLU hidden dim.
  • num_layers: N decoder blocks.

Output: 1-D, length T * V.

A note on finishing the track

If you’re reading this in earnest, you’ve worked through 99 problems on Flax: the Module system, every kind of normalization, every kind of attention, every kind of position encoding, train/eval state, sharding, lifts, surgery. This problem is the closing bracket — the place where the lessons compose.

A real production LLM is mostly this code, scaled up: the same kind of DecoderBlock with D=4096, H=32, d_ff=14336, N=32 (roughly Mistral 7B’s shape), the same RMSNorm, the same SwiGLU. Production models add refinements such as RoPE and grouped-query attention, and some (Gemma, for example) also tie their embeddings. What changes between a capstone toy and a frontier model is mostly scale, data, and engineering. The architectural primitives are the ones in this file.

Good luck. Ship it.
