NNX Mini-LM Capstone — Putting It All Together
The closing problem
This is the final problem of the Flax NNX track. One hundred problems, every primitive of modern decoder-only language modeling, converging here into a single complete artifact: a small but real LLaMA-flavored transformer in pure nnx, end to end.
Every component you’ve built across the last 99 positions is waiting to be plugged in. The architecture is the synthesis of:
- Token embedding + tied output head: nnx-implement-embed (pos 29) for the embedding lookup; nnx-tied-io-embed (pos 99) for the trick of using the same (V, D) matrix as both the input table and the output projection.
- Sinusoidal position encoding: formula-based, no parameters, computed inside __call__. This is the same closed-form sin/cos from the original Transformer paper; see nnx-implement-positional-embed (pos 30) for context.
- N pre-LN causal Transformer blocks, each with:
  - RMSNorm: the nnx-implement-rmsnorm you built in pos 25. Mean-square normalization, one γ vector, no β. Cheaper than LayerNorm, with comparable accuracy in practice. Used by LLaMA, Gemma, Mistral, T5.
  - Causal self-attention with RoPE: the nnx-mha-causal from pos 33 combined with the nnx-mha-rope rotation from pos 40. Q and K are rotated by per-position angles BEFORE the score computation; the lower-triangular mask is applied BEFORE the softmax.
  - SwiGLU FFN: the nnx-swiglu-ffn from pos 50. Three Linears, one element-wise SiLU-gated multiplicative interaction. Used by most recent open LLMs, including LLaMA and Mistral.
  - Pre-LN residual structure: norm goes INSIDE the residual branch, not after the residual sum. This is the nnx-transformer-decoder-block pattern from pos 41.
- Final RMSNorm before the output head: the standard LLaMA-style pre-head normalization.
- Tied output projection: logits = x @ embedding.T. Same matrix as the input embedding, so the head adds no parameters of its own. Pos 99 set this up.
The architecture, top to bottom
token_ids (T,) int32
→ embedding[token_ids] → tok (T, D)
→ + sinusoidal_pos(T, D) → x (T, D)
↓
[DecoderBlock] x num_layers → x (T, D)
| each block, pre-LN style:
| x = x + CausalRopeMHA(RMSNorm(x))
| x = x + SwiGLU_FFN(RMSNorm(x))
↓
RMSNorm(x) → x (T, D)
↓
x @ embedding.T → logits (T, V)
↓
.reshape(-1) → (T*V,)
Three module-level attribute groups: embedding (the (V, D) Param,
used twice), blocks (nnx.List of DecoderBlock), final_norm
(RMSNorm). The output head has no separate parameters; that’s
the tied trick.
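As a rough check on what the tied head saves, the parameter count for the test configuration listed later (V=8, D=8, d_ff=16, L=2) can be tallied by hand. This is a sketch under one assumption: every nnx.Linear keeps its default bias term, as in the recap code below. The helper count_params is hypothetical and not part of the problem's API.

```python
def count_params(V, D, d_ff, L, tied=True):
    """Count trainable scalars in the MiniLM layout (assumes biased Linears)."""
    linear = lambda n_in, n_out: n_in * n_out + n_out   # weight + bias
    attn = 4 * linear(D, D)                       # q, k, v, out projections
    ffn = 2 * linear(D, d_ff) + linear(d_ff, D)   # gate, up, down
    norms = 2 * D                                 # two RMSNorm gamma vectors
    block = attn + ffn + norms
    embedding = V * D
    head = 0 if tied else V * D                   # an untied head would add V*D
    return embedding + L * block + D + head       # + D for the final RMSNorm

tied = count_params(V=8, D=8, d_ff=16, L=2)
untied = count_params(V=8, D=8, d_ff=16, L=2, tied=False)
print(tied, untied, untied - tied)
```

The difference between the two counts is exactly V * D. Note that num_heads does not appear: splitting D into heads changes the reshape, not the number of weights.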
The intro problems revisited
Recall from nnx-module-basics (pos 4) the design philosophy:
“modules are plain Python objects with mutable parameters; tracing
is a separate concern.” Every primitive you’ve built has lived
inside an nnx.Module. Every parameter has been an nnx.Param
with a .value. Every transform has been nnx.split → JAX
transform → nnx.merge.
This capstone composes 99 problems’ worth of those primitives into a single Module. No new tricks. Every line is something you’ve written before.
Component-by-component recap
RMSNorm (pos 25)
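    import jax
    import jax.numpy as jnp
    from flax import nnx

    class RMSNorm(nnx.Module):
        def __init__(self, d, eps, rngs):
            # rngs is unused (ones-init) but keeps the constructor signature uniform.
            self.gamma = nnx.Param(jnp.ones((d,)))
            self.eps = eps

        def __call__(self, x):
            ms = jnp.mean(x ** 2, axis=-1, keepdims=True)   # mean square per row
            return self.gamma * x / jnp.sqrt(ms + self.eps)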
One trainable gamma vector, ones-init. Standard LLaMA-style.
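To see what the normalization does numerically, here is a minimal functional sketch of the same formula (rms_norm is a stand-in name, not the module from pos 25). With gamma at its ones-init, each output row should have a mean square of about 1, and rescaling the input should leave the output essentially unchanged.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, gamma, eps=1e-6):
    # Functional form of the RMSNorm module above: no mean subtraction, no bias.
    ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
    return gamma * x / jnp.sqrt(ms + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 5.0 + 2.0
y = rms_norm(x, jnp.ones((8,)))
mean_square = jnp.mean(y ** 2, axis=-1)          # one value per row
rescaled = rms_norm(10.0 * x, jnp.ones((8,)))    # same direction, larger scale
```

Only the direction of each row survives; the learned gamma then decides how large each feature should be.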
RoPE rotate (pos 40)
    def rotate(x, cos, sin):
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        rx1 = x1 * cos - x2 * sin
        rx2 = x1 * sin + x2 * cos
        rotated = jnp.stack([rx1, rx2], axis=-1)
        return rotated.reshape(*x.shape)
Pair adjacent feature dimensions, apply 2D rotation by per-position angle, restack. Q and K are rotated; V is not.
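The reason rotating Q and K is enough can be checked directly: after rotation, the dot product between a query at position m and a key at position n depends only on the offset between them. The sketch below assumes a single head vector of size 8 and a hypothetical helper rope_at that rotates one vector as if it sat at a given position.

```python
import jax
import jax.numpy as jnp

def rotate(x, cos, sin):
    # Same pairing as the recap: (x[0], x[1]), (x[2], x[3]), ...
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(*x.shape)

def rope_at(x, position, base=10000.0):
    # Rotate a single (Dh,) vector as if it sat at the given position.
    half = x.shape[-1] // 2
    theta = jnp.power(base, -2.0 * jnp.arange(half) / x.shape[-1])
    angles = position * theta
    return rotate(x, jnp.cos(angles), jnp.sin(angles))

kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (8,))
k = jax.random.normal(kk, (8,))

score_a = jnp.dot(rope_at(q, 3), rope_at(k, 1))   # offset 2
score_b = jnp.dot(rope_at(q, 7), rope_at(k, 5))   # offset 2, shifted by 4
norm_before = jnp.linalg.norm(q)
norm_after = jnp.linalg.norm(rope_at(q, 7))       # rotation keeps the length
```

The two scores agree up to float32 rounding, and the vector norm is unchanged, which is what makes RoPE a relative-position scheme rather than an additive one.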
Causal MHA with RoPE (pos 33 + pos 40)
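    class CausalRopeMHA(nnx.Module):
        def __init__(self, d_model, num_heads, base, rngs):
            # Sketch of the setup: head_dim = D // H, four Linear projections.
            self.num_heads = num_heads
            self.head_dim = d_model // num_heads
            self.base = base                      # 10000.0 in this problem
            self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
            self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
            self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
            self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

        def _cos_sin(self, T):
            i = jnp.arange(self.head_dim // 2)
            theta = jnp.power(self.base, -2.0 * i / self.head_dim)
            angles = jnp.arange(T)[:, None] * theta[None, :]
            return jnp.cos(angles), jnp.sin(angles)   # each (T, Dh // 2)

        def __call__(self, x):
            T, _ = x.shape
            H, Dh = self.num_heads, self.head_dim
            # Project, split into heads, move heads first: (H, T, Dh).
            q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
            k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
            v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
            cos, sin = self._cos_sin(T)
            q = rotate(q, cos[None, :, :], sin[None, :, :])   # V is not rotated
            k = rotate(k, cos[None, :, :], sin[None, :, :])
            scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
            mask = jnp.tril(jnp.ones((T, T)))                 # causal mask
            scores = jnp.where(mask == 0, -1e9, scores)
            weights = jax.nn.softmax(scores, axis=-1)
            per_head = jnp.matmul(weights, v)                 # (H, T, Dh)
            return self.out_proj(per_head.transpose(1, 0, 2).reshape(T, H * Dh))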
The -1e9 mask value (NOT -jnp.inf) is a hard-won lesson from
earlier problems: -jnp.inf survives the softmax fine but breaks
JSON serialization through the test harness. -1e9 is large
enough in magnitude that the masked softmax weights underflow to 0 in float32.
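A quick way to convince yourself the finite sentinel is safe is to run the mask on random scores and inspect the result. This is a standalone sketch with a made-up (T, T) score matrix, not the attention module itself.

```python
import jax
import jax.numpy as jnp

T = 4
scores = jax.random.normal(jax.random.PRNGKey(0), (T, T))
mask = jnp.tril(jnp.ones((T, T)))               # 1 on and below the diagonal
masked = jnp.where(mask == 0, -1e9, scores)     # finite sentinel, not -inf
weights = jax.nn.softmax(masked, axis=-1)

future = jnp.triu(weights, k=1)                 # weights placed on future positions
row_sums = weights.sum(axis=-1)
```

Every entry of masked is finite, the weights on future positions vanish, and each row still sums to 1, so the first token attends only to itself.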
SwiGLU FFN (pos 50)
    class SwiGLU(nnx.Module):
        def __init__(self, d_model, d_ff, rngs):
            self.gate = nnx.Linear(d_model, d_ff, rngs=rngs)
            self.up = nnx.Linear(d_model, d_ff, rngs=rngs)
            self.down = nnx.Linear(d_ff, d_model, rngs=rngs)

        def __call__(self, x):
            return self.down(jax.nn.silu(self.gate(x)) * self.up(x))
Three Linears, one multiplicative gate. SiLU on the gate branch only.
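The data flow is easier to see with the weights written out. The following is a bias-free functional sketch with random matrices; the function swiglu and the weight names are illustrative assumptions, not the nnx module above.

```python
import jax
import jax.numpy as jnp

def swiglu(x, w_gate, w_up, w_down):
    # Bias-free sketch: down( silu(x @ w_gate) * (x @ w_up) ).
    gate = jax.nn.silu(x @ w_gate)     # SiLU(z) = z * sigmoid(z), gate branch only
    up = x @ w_up                      # linear branch, no activation
    return (gate * up) @ w_down        # element-wise product, then project back

T, D, D_FF = 4, 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (T, D))
w_gate = jax.random.normal(k2, (D, D_FF))
w_up = jax.random.normal(k3, (D, D_FF))
w_down = jax.random.normal(k4, (D_FF, D))

out = swiglu(x, w_gate, w_up, w_down)
closed = jnp.zeros_like(w_up)                    # an "up" branch that is all zeros
out_closed = swiglu(x, w_gate, closed, w_down)   # the multiplicative path shuts off
```

Because the two branches are multiplied, zeroing either one silences the whole FFN, which is the sense in which one branch gates the other.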
DecoderBlock (pos 41-style)
    class DecoderBlock(nnx.Module):
        def __init__(self, d_model, num_heads, d_ff, rngs):
            self.norm1 = RMSNorm(d_model, eps=1e-6, rngs=rngs)
            self.attn = CausalRopeMHA(d_model, num_heads, base=10000.0, rngs=rngs)
            self.norm2 = RMSNorm(d_model, eps=1e-6, rngs=rngs)
            self.ffn = SwiGLU(d_model, d_ff, rngs=rngs)

        def __call__(self, x):
            x = x + self.attn(self.norm1(x))
            x = x + self.ffn(self.norm2(x))
            return x
Pre-LN: norm INSIDE the residual branch. The skip path (x itself)
is never normalized — clean gradient flow through the residual.
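The difference between the two placements shows up even with a branch that contributes nothing. In this sketch, silent is a hypothetical branch returning zeros and rms_norm is a gamma-free stand-in for the module; neither comes from the track.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, eps=1e-6):
    return x / jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + eps)

def pre_ln(x, branch):
    return x + branch(rms_norm(x))     # norm inside the branch; skip path untouched

def post_ln(x, branch):
    return rms_norm(x + branch(x))     # norm after the sum; skip path rescaled

silent = lambda h: jnp.zeros_like(h)   # hypothetical branch that outputs nothing
x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 8))

pre_out = pre_ln(x, silent)
post_out = post_ln(x, silent)
```

With pre-LN the block reduces to the identity, so the input passes through unchanged. With post-LN the same input is renormalized on every block, which is the rescaling of the skip path that the paragraph above refers to.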
Sinusoidal position encoding (helper, no params)
    def sinusoidal_pos_encoding(T, D):
        pos = jnp.arange(T, dtype=jnp.float32)[:, None]
        i = jnp.arange(D // 2, dtype=jnp.float32)[None, :]
        div = jnp.power(10000.0, (2.0 * i) / float(D))
        angles = pos / div
        pe = jnp.zeros((T, D), dtype=jnp.float32)
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        pe = pe.at[:, 1::2].set(jnp.cos(angles))
        return pe
No parameters, no Module — just a function called inside the forward. This is on top of the RoPE inside attention; the sinusoidal encoding adds an absolute position signal at the embedding layer, while RoPE adds a relative-position signal inside scoring. Real LLaMA only has RoPE; we keep both here for pedagogy.
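Two properties make the helper easy to sanity-check: position 0 always encodes to alternating 0 and 1, and the first column is simply sin(pos) because its divisor is 10000^0 = 1. The sketch below repeats the helper (assuming an even D) and pulls out both.

```python
import jax.numpy as jnp

def sinusoidal_pos_encoding(T, D):
    # Copy of the helper above; assumes D is even.
    pos = jnp.arange(T, dtype=jnp.float32)[:, None]
    i = jnp.arange(D // 2, dtype=jnp.float32)[None, :]
    angles = pos / jnp.power(10000.0, (2.0 * i) / float(D))
    pe = jnp.zeros((T, D), dtype=jnp.float32)
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    return pe.at[:, 1::2].set(jnp.cos(angles))

pe = sinusoidal_pos_encoding(4, 8)
first_row = pe[0]     # position 0: sin(0) = 0 in even slots, cos(0) = 1 in odd slots
fastest = pe[:, 0]    # pair 0 has divisor 1, so this column is sin(pos)
```

If first_row is not [0, 1, 0, 1, ...], the sin and cos slots have most likely been swapped.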
MiniLM (the top-level Module)
    class MiniLM(nnx.Module):
        def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, rngs):
            self.d_model = d_model
            key = rngs.params()
            self.embedding = nnx.Param(
                jax.random.normal(key, (vocab_size, d_model))
                * (1.0 / jnp.sqrt(d_model))
            )
            self.blocks = nnx.List([
                DecoderBlock(d_model, num_heads, d_ff, rngs=rngs)
                for _ in range(num_layers)
            ])
            self.final_norm = RMSNorm(d_model, eps=1e-6, rngs=rngs)

        def __call__(self, token_ids):
            T = token_ids.shape[0]
            tok = self.embedding.value[token_ids]
            x = tok + sinusoidal_pos_encoding(T, self.d_model)
            for block in self.blocks:
                x = block(x)
            x = self.final_norm(x)
            return x @ self.embedding.value.T
embedding is the only nnx.Param declared directly on MiniLM (the
blocks and the final norm own the rest), and it is used twice: as
the input lookup (rows indexed by token_ids) and as the output
projection (matmul against its transpose).
Why these choices and not others
Pre-LN, not post-LN. Post-LN puts the norm AFTER the residual
sum (LN(x + branch)). It was the original Transformer formulation,
and it typically needed a careful learning-rate warmup. Pre-LN (used by
GPT-2 onward and most modern LLMs) tends to train stably with little or
no warmup because the residual path is never rescaled by an
activation-dependent factor.
RMSNorm, not LayerNorm. Comparable accuracy at lower cost. LayerNorm needs both the mean and the variance; RMSNorm needs only the mean square, so it skips the mean subtraction. It also drops the β bias, which trained models tend to use minimally anyway.
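The relationship between the two norms can be made concrete: on input that is already zero-mean, the mean square equals the variance, so RMSNorm and a scale-free, bias-free LayerNorm coincide. The sketch below uses simplified stand-in functions (no learned parameters) to show both the agreement and the difference on uncentered input.

```python
import jax
import jax.numpy as jnp

EPS = 1e-6

def rms_norm(x):
    return x / jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + EPS)

def layer_norm(x):
    # LayerNorm without its learned scale and bias, for comparison only.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + EPS)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) + 3.0   # non-zero mean
centered = x - jnp.mean(x, axis=-1, keepdims=True)

same = bool(jnp.allclose(rms_norm(centered), layer_norm(x), atol=1e-5))
differs = not bool(jnp.allclose(rms_norm(x), layer_norm(x), atol=1e-5))
```

So the saving comes from skipping the centering step; the bet is that the network does not need it.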
SwiGLU, not vanilla FFN. At a fixed parameter count, SwiGLU
has been reported to outperform Linear → ReLU → Linear fairly
consistently. The multiplicative gate lets the network learn
input-dependent amplification of each hidden unit.
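"Fixed parameter count" has a concrete consequence for sizing: SwiGLU has three weight matrices where the vanilla FFN has two, so its hidden width has to shrink to about two thirds to keep the budget equal. The arithmetic below ignores biases and assumes the conventional 4x expansion for the vanilla FFN; both are assumptions made for illustration.

```python
def vanilla_ffn_weights(d_model, d_ff):
    return 2 * d_model * d_ff          # Linear -> ReLU -> Linear (biases ignored)

def swiglu_weights(d_model, d_ff):
    return 3 * d_model * d_ff          # gate, up, down (biases ignored)

d_model = 4096
vanilla_hidden = 4 * d_model                    # assumed 4x expansion
matched_hidden = 2 * vanilla_hidden // 3        # 2/3 of it keeps the budget equal

budget = vanilla_ffn_weights(d_model, vanilla_hidden)
swiglu_cost = swiglu_weights(d_model, matched_hidden)
print(matched_hidden, budget, swiglu_cost)
```

The two totals agree to within a rounding remainder, which is why SwiGLU models tend to use hidden widths that are not a clean multiple of 4 * D.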
RoPE, not learned absolute position. Learned positional embeddings have a fixed maximum length and don’t extrapolate. RoPE is parameter-free, naturally encodes relative positions, and extrapolates (with caveats) beyond training length.
Tied embeddings, not separate output head. Cuts V * D
parameters from the head. In a real LLM with V=128k, D=4096,
that’s roughly half a billion parameters. Tying is also often
reported to give a small perplexity improvement.
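The half-billion figure is plain multiplication. The snippet assumes "128k" means 128,000; a power-of-two vocabulary of 131,072 gives a slightly larger number.

```python
V = 128_000      # assumption: "128k" read as 128,000 (131,072 is also common)
D = 4096
head_params = V * D                  # weights an untied output head would add
head_params_pow2 = 131_072 * D       # the same count for a power-of-two vocabulary
print(f"{head_params:,}")            # 524,288,000
```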
What you’re NOT doing here
- No KV-cache — this is a forward pass for training/eval, not autoregressive decoding. KV-cache machinery (pos 36, pos 37) is for the generation loop.
- No dropout — disabled for clean numerical reproducibility. In a real training run, dropout (with a per-call rng) would sit inside attention and the FFN.
- No mixed precision — pos 67 covers this; the capstone runs in float32 for clarity.
- No scaling tricks (gradient checkpointing, FSDP, ZeRO) — pos 84-90 cover those. Adding them is mechanical once the base model exists.
Common pitfalls
- Forgetting astype(jnp.int32) for token_ids. They arrive as floats; embedding[token_ids] with float indices fails.
- Plain Python list for blocks: must be nnx.List(...) or params disappear from the state tree.
- Wrong tied-head shape: logits = x @ embedding.T (NOT embedding). Without .T the dimensions generally don’t match.
- Skipping the final norm: it matters for training stability and is standard in LLaMA-style models.
- -jnp.inf in the causal mask: survives softmax but breaks JSON serialization through the test harness. Use -1e9.
- Forgetting .value for the transpose: self.embedding.T operates on the wrapper (no .T). self.embedding.value.T is what you want.
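The tied-head pitfall has one extra wrinkle worth a quick check: when V equals D, as in the listed test inputs (V=8, D=8), dropping .T raises no shape error at all. A minimal sketch, using random stand-in arrays rather than the real model:

```python
import jax
import jax.numpy as jnp

T, V, D = 4, 8, 8                                   # V == D, as in the test inputs
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(k1, (V, D)) / jnp.sqrt(D)
x = jax.random.normal(k2, (T, D))                   # stand-in for final-norm output

logits = x @ embedding.T                            # correct tied head: (T, V)
wrong = x @ embedding                               # also (T, 8): no shape error
one_logit = jnp.dot(x[1], embedding[3])             # what logits[1, 3] should equal
```

Both products have the same shape but different values, so this mistake only shows up as wrong numbers. Testing with V different from D during development turns it back into a loud shape error.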
Problem
Implement mini_lm_capstone(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
- Cast all configs (vocab_size, d_model, num_heads, d_ff, num_layers) to int. Cast token_ids to jnp.int32.
- Build RMSNorm, CausalRopeMHA, SwiGLU, DecoderBlock, and a top-level MiniLM Module.
  - MiniLM.embedding: nnx.Param of shape (V, D), normal-init scaled by 1/sqrt(D).
  - MiniLM.blocks: nnx.List of num_layers DecoderBlocks.
  - MiniLM.final_norm: RMSNorm.
- Forward: tok = embedding.value[ids], add sinusoidal_pos_encoding(T, D), run blocks, final norm, then logits = x @ embedding.value.T.
- Return logits.reshape(-1).
Use nnx.Rngs(int(seed)) once at the top; pass it through to
every submodule.
Test inputs: T=4, V=8, D=8, H=2, d_ff=16, L=2 (and variations).
Inputs:
- seed: float (cast to int).
- token_ids: 1-D float array of token ids (cast inside).
- vocab_size, d_model, num_heads, d_ff, num_layers: ints (passed as floats).
Output: 1-D, length T * vocab_size.
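Since everything arrives as floats, the casting step is worth isolating before any module is built. The helper below is a hypothetical sketch of that first step only (cast_inputs is not a name the problem requires), along with the output length the harness expects.

```python
import jax.numpy as jnp

def cast_inputs(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
    # Hypothetical helper: the harness passes everything as floats.
    ids = jnp.asarray(token_ids).astype(jnp.int32)
    config = dict(
        vocab_size=int(vocab_size), d_model=int(d_model),
        num_heads=int(num_heads), d_ff=int(d_ff), num_layers=int(num_layers),
    )
    return int(seed), ids, config

seed, ids, config = cast_inputs(0.0, [1.0, 5.0, 2.0, 7.0], 8.0, 8.0, 2.0, 16.0, 2.0)
expected_len = ids.shape[0] * config["vocab_size"]   # length of logits.reshape(-1)
```

For the listed test shape (T=4, V=8) the flattened output has 32 entries.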
Closing the track
If you’ve worked through the previous 99 problems, you’ve built
every component of a modern decoder-only language model from
scratch — twice, in two different Flax APIs. You’ve seen how
nnx.split underlies jit and sharding. You’ve seen how
nnx.Variable enables eager debugging. You’ve seen the entire
bridge layer between Linen and nnx, the Orbax checkpoint format,
multi-host data parallelism, RoPE, RMSNorm, SwiGLU, KV-cache,
surgery, freezing, warm-starting.
A real production LLM is mostly this code, scaled up. Same
DecoderBlock, but with D=4096, H=32, d_ff=14336, L=32. Same
RMSNorm. Same SwiGLU with a wider hidden dim. Often the same tied
embeddings, just bigger, although some large models keep a separate
output head. What changes between this capstone and a frontier
model is mostly scale, data, and engineering, not
architecture. The architectural primitives are largely the ones
in this file.
Good luck. Write it once. Read it carefully. Watch the logits come out.
Track 12 closes here.