
NNX Mini-LM Capstone — Putting It All Together

The closing problem

This is the final problem of the Flax NNX track. One hundred problems, every primitive of modern decoder-only language modeling, converging here into a single complete artifact: a small but real LLaMA-flavored transformer in pure nnx, end to end.

Every component you’ve built across the last 99 positions is waiting to be plugged in. The architecture is the synthesis of:

  • Token embedding + tied output head: nnx-implement-embed (pos 29) for the embedding lookup; nnx-tied-io-embed (pos 99) for the trick of using the same (V, D) matrix as both the input table and the output projection.
  • Sinusoidal position encoding — formula-based, no parameters, computed inside __call__ (this is the same closed-form sin/cos from the original Transformer paper, see nnx-implement-positional-embed, pos 30, for context).
  • N pre-LN causal Transformer blocks, each with:
    • RMSNorm — the nnx-implement-rmsnorm you built in pos 25. Mean-square normalization, one γ vector, no β. Cheaper than LayerNorm, with comparable accuracy. Used by LLaMA, Gemma, Mistral, T5.
    • Causal self-attention with RoPE — the nnx-mha-causal from pos 33 combined with the nnx-mha-rope rotation from pos 40. Q and K rotated by per-position angles BEFORE the score computation; lower-triangular mask BEFORE the softmax.
    • SwiGLU FFN — the nnx-swiglu-ffn from pos 50. Three Linears, one element-wise SiLU-gated multiplicative interaction. Used by LLaMA, Mistral, PaLM and many other modern LLMs.
    • Pre-LN residual structure — norm goes INSIDE the residual branch, not after the residual sum. This is the nnx-transformer-decoder-block pattern from pos 41.
  • Final RMSNorm before the output head — the standard LLaMA-style pre-head normalization.
  • Tied output projection: logits = x @ embedding.T. Same matrix as the input embedding, so the head adds no parameters of its own (it saves V × D, halving the combined embedding + head count). Pos 99 set this up.

The architecture, top to bottom

token_ids (T,) int32
  → embedding[token_ids]              → tok    (T, D)
  → + sinusoidal_pos(T, D)            → x      (T, D)
  ↓
[DecoderBlock] x num_layers           → x      (T, D)
  |   each block, pre-LN style:
  |     x = x + CausalRopeMHA(RMSNorm(x))
  |     x = x + SwiGLU_FFN(RMSNorm(x))
  ↓
RMSNorm(x)                            → x      (T, D)
  ↓
x @ embedding.T                       → logits (T, V)
  ↓
.reshape(-1)                          → (T*V,)

Three parameter-carrying attributes (plus the plain-int d_model): embedding (the (V, D) Param, used twice), blocks (nnx.List of DecoderBlock), final_norm (RMSNorm). The output head has no separate parameters — that’s the tied trick.

The intro problems revisited

Recall from nnx-module-basics (pos 4) the design philosophy: “modules are plain Python objects with mutable parameters; tracing is a separate concern.” Every primitive you’ve built has lived inside an nnx.Module. Every parameter has been an nnx.Param with a .value. Every transform has been nnx.split → JAX transform → nnx.merge.

This capstone composes 99 problems’ worth of those primitives into a single Module. No new tricks. Every line is something you’ve written before.

Component-by-component recap

RMSNorm (pos 25)

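import jax
import jax.numpy as jnp
from flax import nnx

class RMSNorm(nnx.Module):
    def __init__(self, d, eps, rngs):
        # rngs is unused (ones-init needs no randomness); kept for a uniform signature.
        self.gamma = nnx.Param(jnp.ones((d,)))
        self.eps = eps

    def __call__(self, x):
        # Mean of squares over the feature axis: no mean subtraction, no beta.
        ms = jnp.mean(x ** 2, axis=-1, keepdims=True)
        return self.gamma * x / jnp.sqrt(ms + self.eps)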

One trainable gamma vector, ones-init. Standard LLaMA-style.
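
As a quick numerical check, here is a NumPy stand-in for the module above. It uses NumPy instead of jnp so it runs without JAX, and gamma is passed in rather than stored as an nnx.Param. It is a sketch, not the graded solution. It shows two things: with gamma = 1, each output row has a root-mean-square of about 1, and rescaling the input by a positive constant barely changes the output.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # Mean of squares over the feature axis; no mean subtraction, no beta.
    ms = np.mean(x ** 2, axis=-1, keepdims=True)
    return gamma * x / np.sqrt(ms + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8)).astype(np.float32)
gamma = np.ones(8, dtype=np.float32)  # ones-init, as in the module above

y = rms_norm(x, gamma)
# With gamma = 1, every output row has root-mean-square ~1.
row_rms = np.sqrt(np.mean(y ** 2, axis=-1))
# Scaling a row by a positive constant leaves the output (almost) unchanged;
# only eps breaks exact scale invariance.
y_scaled = rms_norm(3.0 * x, gamma)
```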

RoPE rotate (pos 40)

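def rotate(x, cos, sin):
    # x: (..., Dh); cos and sin broadcast against (..., Dh/2).
    x1 = x[..., 0::2]   # even features: first element of each pair
    x2 = x[..., 1::2]   # odd features: second element of each pair
    rx1 = x1 * cos - x2 * sin   # 2D rotation of each (x1, x2) pair
    rx2 = x1 * sin + x2 * cos
    rotated = jnp.stack([rx1, rx2], axis=-1)   # (..., Dh/2, 2)
    return rotated.reshape(*x.shape)           # re-interleave to (..., Dh)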

Pair adjacent feature dimensions (0, 1), (2, 3), …, rotate each pair by the angle position × θ_i (one frequency θ_i per pair), and restack. Q and K are rotated; V is not.
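
The point of rotating Q and K this way is that their dot product depends only on the distance between the two positions. The NumPy sketch below is an illustrative port of rotate and _cos_sin, not the graded code. It places the same q and k vectors at different absolute positions with the same offset and compares the resulting scores.

```python
import numpy as np

def cos_sin(T, head_dim, base=10000.0):
    i = np.arange(head_dim // 2)
    theta = base ** (-2.0 * i / head_dim)            # one frequency per pair
    angles = np.arange(T)[:, None] * theta[None, :]  # (T, head_dim/2)
    return np.cos(angles), np.sin(angles)

def rotate(x, cos, sin):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = np.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(x.shape)

Dh, T = 8, 16
cos, sin = cos_sin(T, Dh)
rng = np.random.default_rng(0)
q = rng.normal(size=Dh)
k = rng.normal(size=Dh)

def score(m, n):
    # Same q and k content, placed at query position m and key position n.
    return rotate(q, cos[m], sin[m]) @ rotate(k, cos[n], sin[n])

# Offset m - n = 2 in both cases, at different absolute positions.
s_a = score(5, 3)
s_b = score(12, 10)
# Rotations preserve length, so RoPE never rescales q or k.
norms_match = np.isclose(np.linalg.norm(rotate(q, cos[7], sin[7])), np.linalg.norm(q))
```

s_a and s_b agree up to floating-point error. This is the "relative position" property that RoPE contributes inside scoring.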

Causal MHA with RoPE (pos 33 + pos 40)

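class CausalRopeMHA(nnx.Module):
    def __init__(self, d_model, num_heads, base, rngs):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # must be even: RoPE rotates feature pairs
        self.base = base
        # Four (D, D) projections, created in q, k, v, out order.
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def _cos_sin(self, T):
        # One frequency per feature pair; angle = position * frequency.
        i = jnp.arange(self.head_dim // 2)
        theta = jnp.power(self.base, -2.0 * i / self.head_dim)
        angles = jnp.arange(T)[:, None] * theta[None, :]   # (T, Dh/2)
        return jnp.cos(angles), jnp.sin(angles)

    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # (T, D) -> (T, H, Dh) -> (H, T, Dh)
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        cos, sin = self._cos_sin(T)
        # Rotate Q and K (not V) before scoring; the leading None broadcasts over heads.
        q = rotate(q, cos[None, :, :], sin[None, :, :])
        k = rotate(k, cos[None, :, :], sin[None, :, :])
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)   # (H, T, T)
        mask = jnp.tril(jnp.ones((T, T)))   # 1 where key position <= query position
        scores = jnp.where(mask == 0, -1e9, scores)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)   # (H, T, Dh)
        # (H, T, Dh) -> (T, H * Dh), then the output projection.
        return self.out_proj(per_head.transpose(1, 0, 2).reshape(T, H * Dh))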

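The -1e9 mask value (NOT -jnp.inf) is a hard-won lesson from earlier problems. -jnp.inf survives the softmax here, because the causal diagonal leaves every row at least one finite score, but it breaks JSON serialization through the test harness. -1e9 is large enough that, once softmax subtracts the row max, exp(-1e9 - max) underflows to exactly 0 in float32.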
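
A small NumPy sketch compares the two mask values. NumPy stands in for jnp, and the softmax is the usual max-subtracted form, written out by hand. The first part checks that masked (future) positions get exactly zero weight. The last two lines show what happens if a row is ever fully masked, which the causal mask alone never does: -1e9 degrades to a uniform row, while -inf produces NaN.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)  # standard max-subtraction for stability
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

T = 4
rng = np.random.default_rng(0)
scores = rng.normal(size=(T, T)).astype(np.float32)
mask = np.tril(np.ones((T, T)))  # lower-triangular: the diagonal is always kept

weights = softmax(np.where(mask == 0, np.float32(-1e9), scores))
# Future positions get exactly 0.0 weight; every row still sums to 1.

# A row with every entry masked:
row_1e9 = softmax(np.full((1, T), -1e9, dtype=np.float32))   # uniform 1/T
with np.errstate(invalid="ignore"):
    row_inf = softmax(np.full((1, T), -np.inf, dtype=np.float32))  # -inf - (-inf) = NaN
```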

SwiGLU FFN (pos 50)

class SwiGLU(nnx.Module):
    def __init__(self, d_model, d_ff, rngs):
        self.gate = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.up = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.down = nnx.Linear(d_ff, d_model, rngs=rngs)

    def __call__(self, x):
        return self.down(jax.nn.silu(self.gate(x)) * self.up(x))

Three Linears, one multiplicative gate. SiLU on the gate branch only.

DecoderBlock (pos 41-style)

class DecoderBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.norm1 = RMSNorm(d_model, eps=1e-6, rngs=rngs)
        self.attn = CausalRopeMHA(d_model, num_heads, base=10000.0, rngs=rngs)
        self.norm2 = RMSNorm(d_model, eps=1e-6, rngs=rngs)
        self.ffn = SwiGLU(d_model, d_ff, rngs=rngs)

    def __call__(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

Pre-LN: norm INSIDE the residual branch. The skip path (x itself) is never normalized — clean gradient flow through the residual.

Sinusoidal position encoding (helper, no params)

def sinusoidal_pos_encoding(T, D):
    pos = jnp.arange(T, dtype=jnp.float32)[:, None]
    i = jnp.arange(D // 2, dtype=jnp.float32)[None, :]
    div = jnp.power(10000.0, (2.0 * i) / float(D))
    angles = pos / div
    pe = jnp.zeros((T, D), dtype=jnp.float32)
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    pe = pe.at[:, 1::2].set(jnp.cos(angles))
    return pe

No parameters, no Module — just a function called inside the forward. This is on top of the RoPE inside attention; the sinusoidal encoding adds an absolute position signal at the embedding layer, while RoPE adds a relative-position signal inside scoring. Real LLaMA only has RoPE; we keep both here for pedagogy.
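
For intuition about what the helper produces, here is a direct NumPy port, with in-place slice assignment replacing .at[...].set. Three properties hold. Position 0 is all sin = 0 and cos = 1. Every (sin, cos) column pair lies on the unit circle. Moving k positions forward is a fixed rotation of each pair, which is the property the original Transformer paper gave for choosing sinusoids.

```python
import numpy as np

def sinusoidal_pos_encoding(T, D):
    pos = np.arange(T, dtype=np.float32)[:, None]
    i = np.arange(D // 2, dtype=np.float32)[None, :]
    div = np.power(10000.0, (2.0 * i) / float(D))
    angles = pos / div                 # (T, D/2)
    pe = np.zeros((T, D), dtype=np.float32)
    pe[:, 0::2] = np.sin(angles)       # even columns: sin
    pe[:, 1::2] = np.cos(angles)       # odd columns: cos
    return pe

T, D = 4, 8
pe = sinusoidal_pos_encoding(T, D)

# Each (sin, cos) pair has unit norm at every position.
pair_norm = pe[:, 0::2] ** 2 + pe[:, 1::2] ** 2

# Shifting by k positions rotates each pair by k * w (angle-addition identities).
k = 2
w = 1.0 / np.power(10000.0, 2.0 * np.arange(D // 2) / D)  # per-pair frequencies
s, c = pe[1, 0::2], pe[1, 1::2]                           # position t = 1
s_shift = s * np.cos(k * w) + c * np.sin(k * w)           # sin((t + k) w)
c_shift = c * np.cos(k * w) - s * np.sin(k * w)           # cos((t + k) w)
```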

MiniLM (the top-level Module)

class MiniLM(nnx.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, rngs):
        self.d_model = d_model
        key = rngs.params()
        self.embedding = nnx.Param(
            jax.random.normal(key, (vocab_size, d_model))
            * (1.0 / jnp.sqrt(d_model))
        )
        self.blocks = nnx.List([
            DecoderBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.final_norm = RMSNorm(d_model, eps=1e-6, rngs=rngs)

    def __call__(self, token_ids):
        T = token_ids.shape[0]
        tok = self.embedding.value[token_ids]
        x = tok + sinusoidal_pos_encoding(T, self.d_model)
        for block in self.blocks:
            x = block(x)
        x = self.final_norm(x)
        return x @ self.embedding.value.T

embedding is the only nnx.Param that MiniLM owns directly (every other Param lives inside blocks or final_norm), and it is used twice: as the input lookup (rows indexed by token_ids) and as the output projection (matmul against its transpose).
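
The NumPy sketch below shows both uses of the single matrix. Random values stand in for MiniLM.embedding.value, and the blocks and final norm are skipped (treated as the identity) purely to keep the example short. Each logit is the dot product of a hidden state with one embedding row. Because reshape(-1) flattens row-major, entry (t, v) lands at index t * V + v of the returned vector.

```python
import numpy as np

V, D, T = 8, 8, 4
rng = np.random.default_rng(0)
E = rng.normal(size=(V, D)) / np.sqrt(D)   # stand-in for MiniLM.embedding.value
token_ids = np.array([3, 1, 4, 1], dtype=np.int32)

tok = E[token_ids]            # input side: row lookup, (T, D)
x = tok                       # pretend the blocks + final norm were the identity
logits = x @ E.T              # output side: same matrix, transposed, (T, V)
flat = logits.reshape(-1)     # what mini_lm_capstone returns, (T * V,)
```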

Why these choices and not others

Pre-LN, not post-LN. Post-LN puts the norm AFTER the residual sum (LN(x + branch)). It was the original Transformer formulation, and it needs a careful learning-rate warmup to train reliably. Pre-LN (used by GPT-2 onward and nearly every modern LLM) is far less sensitive to warmup, because the residual path is never rescaled by an activation-dependent factor: only the branch input is normalized.
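
The difference is easiest to see with a branch that contributes nothing. In this NumPy sketch, RMSNorm stands in for both norms (the argument does not depend on which norm is used). zero_branch is a hypothetical do-nothing attention/FFN. Pre-LN returns the residual stream untouched, while post-LN still rescales it by a data-dependent factor.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)

def pre_ln(x, branch):
    return x + branch(rms_norm(x))      # norm inside the branch; skip path untouched

def post_ln(x, branch):
    return rms_norm(x + branch(x))      # norm applied to the residual sum

def zero_branch(h):
    # Hypothetical branch that contributes nothing to the residual stream.
    return np.zeros_like(h)

rng = np.random.default_rng(0)
x = 5.0 * rng.normal(size=(4, 8))       # residual stream with RMS far from 1

out_pre = pre_ln(x, zero_branch)        # identical to x
out_post = post_ln(x, zero_branch)      # x rescaled to unit RMS per row
```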

RMSNorm, not LayerNorm. Comparable accuracy for less compute. LayerNorm needs both a mean and a variance (two reductions over the feature axis) plus a β bias; RMSNorm needs only the mean square and has no β.
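
The relationship between the two norms fits in a few lines. With γ = 1 and β = 0, LayerNorm is exactly RMSNorm applied to the mean-centred input. The NumPy sketch below shares eps between both norms and omits the affine parameters.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = np.mean(x, axis=-1, keepdims=True)                 # reduction 1: mean
    var = np.mean((x - mu) ** 2, axis=-1, keepdims=True)    # reduction 2: variance
    return (x - mu) / np.sqrt(var + eps)

def rms_norm(x, eps=1e-6):
    ms = np.mean(x ** 2, axis=-1, keepdims=True)            # single reduction
    return x / np.sqrt(ms + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, size=(4, 8))     # inputs with a clearly nonzero mean

centred = x - x.mean(axis=-1, keepdims=True)
ln_out = layer_norm(x)
rms_on_centred = rms_norm(centred)       # matches ln_out
rms_on_raw = rms_norm(x)                 # differs: no centring step
```

RMSNorm is therefore LayerNorm without the centring step (and without β). On inputs with a nonzero mean, the two outputs differ.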

SwiGLU, not vanilla FFN. At a fixed parameter count, SwiGLU consistently outperforms Linear → ReLU → Linear. The multiplicative gate lets the network learn an input-dependent amplification for each hidden unit.
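
"Fixed parameter count" needs care, because SwiGLU has three weight matrices instead of two. The usual convention, the one LLaMA uses, is to shrink the hidden width to 2/3 of the classic 4D. The NumPy sketch below works under that assumption, with biases omitted and random weights. It confirms that the two FFNs then have the same number of weights and the same output shape.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x, w_gate, w_up, w_down):
    # SiLU on the gate branch only, multiplied element-wise by the up branch.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def relu_ffn(x, w_in, w_out):
    return np.maximum(x @ w_in, 0.0) @ w_out

D = 12
d_ff_relu = 4 * D                # classic 4x expansion, 2 matrices
d_ff_swiglu = (2 * 4 * D) // 3   # 2/3 of that, 3 matrices, same total weights

rng = np.random.default_rng(0)
x = rng.normal(size=(4, D))
w_gate = rng.normal(size=(D, d_ff_swiglu))
w_up = rng.normal(size=(D, d_ff_swiglu))
w_down = rng.normal(size=(d_ff_swiglu, D))
w_in = rng.normal(size=(D, d_ff_relu))
w_out = rng.normal(size=(d_ff_relu, D))

n_swiglu = w_gate.size + w_up.size + w_down.size
n_relu = w_in.size + w_out.size
y_swiglu = swiglu_ffn(x, w_gate, w_up, w_down)
y_relu = relu_ffn(x, w_in, w_out)
```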

RoPE, not learned absolute position. Learned positional embeddings have a fixed maximum length and don’t extrapolate. RoPE is parameter-free, naturally encodes relative positions, and extrapolates (with caveats) beyond training length.

Tied embeddings, not separate output head. Cuts V * D parameters from the head. In a real LLM with V=128k, D=4096, that’s half a billion parameters. Plus a small empirical perplexity improvement.
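
The saving is easy to count for this capstone. The plain-Python sketch below makes two assumptions: every nnx.Linear keeps its default bias (use_bias=True), and a hypothetical untied head would be a bias-free (D, V) matrix.

```python
def mini_lm_param_count(V, D, F, L, tied=True):
    """Parameter count of the MiniLM above, assuming default (biased) nnx.Linear
    layers and, if untied, a hypothetical bias-free (D, V) output head."""
    attn = 4 * (D * D + D)               # q, k, v, out projections
    ffn = 2 * (D * F + F) + (F * D + D)  # gate and up (D -> F), down (F -> D)
    norms = 2 * D                        # norm1 and norm2 gammas
    block = attn + ffn + norms
    total = V * D + L * block + D        # embedding + blocks + final_norm
    if not tied:
        total += D * V                   # a separate output head
    return total

small_tied = mini_lm_param_count(8, 8, 16, 2)
small_untied = mini_lm_param_count(8, 8, 16, 2, tied=False)
head_saving_large = 128_000 * 4096       # V * D for V = 128k, D = 4096
```

For the test configuration (V=8, D=8, d_ff=16, L=2), this gives 1,528 parameters tied versus 1,592 untied. For V=128k and D=4096, the head alone is 524,288,000 parameters, the "half a billion" above.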

What you’re NOT doing here

  • No KV-cache — this is a forward pass for training/eval, not autoregressive decoding. KV-cache machinery (pos 36, pos 37) is for the generation loop.
  • No dropout — disabled for clean numerical reproducibility. In a training run that does use dropout, it would sit (with a per-call rng) inside attention and the FFN.
  • No mixed precision — pos 67 covers this; the capstone runs in float32 for clarity.
  • No scaling tricks (gradient checkpointing, FSDP, ZeRO) — pos 84-90 cover those. Adding them is mechanical once the base model exists.

Common pitfalls

  • Forgetting astype(jnp.int32) for token_ids. They arrive as floats; embedding[token_ids] with float indices fails.
  • Plain Python list for blocks — must be nnx.List(...) or params disappear from the state tree.
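  • Wrong tied-head shape: logits = x @ embedding.T (NOT embedding). Without .T, (T, D) @ (V, D) raises a shape error whenever V ≠ D. When V == D, as in the default test config (V=8, D=8), it runs silently and returns wrong logits.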
  • Skipping the final norm — required for training stability; every modern LLM has it.
  • -jnp.inf in the causal mask — survives softmax but breaks JSON serialization through the test harness. Use -1e9.
  • Forgetting .value for the transpose: self.embedding.T operates on the nnx.Param wrapper (no .T). self.embedding.value.T is what you want.
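
Two of these pitfalls can be reproduced in NumPy. NumPy behaves like jnp here, although the exception types JAX raises may differ. The missing transpose is the dangerous one for the default tests: with V == D, it produces logits of the right shape and the wrong values.

```python
import numpy as np

V, D, T = 8, 8, 4
rng = np.random.default_rng(0)
E = rng.normal(size=(V, D))
x = rng.normal(size=(T, D))

# Pitfall: with V == D, no shape check catches the missing transpose.
wrong = x @ E       # runs, but contracts against embedding columns, not rows
right = x @ E.T

# With V != D, the same mistake fails loudly.
E_rect = rng.normal(size=(10, D))
try:
    _ = x @ E_rect
    shape_error = False
except ValueError:
    shape_error = True

# Pitfall: float token ids cannot index an array.
ids = np.array([3.0, 1.0, 4.0, 1.0])
try:
    _ = E[ids]
    float_index_error = False
except IndexError:
    float_index_error = True
tok = E[ids.astype(np.int32)]   # the fix from the first bullet
```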

Problem

Implement mini_lm_capstone(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):

  1. Cast all configs (vocab_size, d_model, num_heads, d_ff, num_layers) to int. Cast token_ids to jnp.int32.
  2. Build RMSNorm, CausalRopeMHA, SwiGLU, DecoderBlock, and a top-level MiniLM Module.
    • MiniLM.embedding: nnx.Param of shape (V, D), normal-init scaled by 1/sqrt(D).
    • MiniLM.blocks: nnx.List of num_layers DecoderBlocks.
    • MiniLM.final_norm: RMSNorm.
  3. Forward: tok = embedding.value[ids], add sinusoidal_pos_encoding(T, D), run blocks, final norm, then logits = x @ embedding.value.T.
  4. Return logits.reshape(-1).

Use nnx.Rngs(int(seed)) once at the top; pass it through to every submodule.

Test inputs: T=4, V=8, D=8, H=2, d_ff=16, L=2 (and variations).

Inputs:

  • seed: float (cast to int).
  • token_ids: 1-D float array of token ids (cast inside).
  • vocab_size, d_model, num_heads, d_ff, num_layers: ints (passed as floats).

Output: 1-D, length T * vocab_size.
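
Step 1 and the shape constraints can be sketched before any Flax code is written. prepare_inputs is a hypothetical helper, not part of the required signature. Its asserts encode assumptions implied by this write-up rather than stated grader rules. Heads must tile d_model, and RoPE rotates feature pairs, so head_dim must be even. An even head_dim also makes d_model even, as the sinusoidal helper requires.

```python
import numpy as np

def prepare_inputs(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
    """Hypothetical helper for step 1: cast the float inputs and sanity-check shapes."""
    cfg = dict(
        seed=int(seed),
        vocab_size=int(vocab_size),
        d_model=int(d_model),
        num_heads=int(num_heads),
        d_ff=int(d_ff),
        num_layers=int(num_layers),
    )
    ids = np.asarray(token_ids).astype(np.int32)   # jnp.int32 in the real solution
    assert cfg["d_model"] % cfg["num_heads"] == 0  # heads tile d_model
    head_dim = cfg["d_model"] // cfg["num_heads"]
    assert head_dim % 2 == 0                       # RoPE needs feature pairs
    # jnp indexing clamps out-of-range gather indices instead of raising,
    # so an explicit range check catches bad ids early.
    assert ids.ndim == 1 and np.all((ids >= 0) & (ids < cfg["vocab_size"]))
    expected_len = ids.shape[0] * cfg["vocab_size"]  # length of the flattened logits
    return cfg, ids, expected_len

cfg, ids, n_out = prepare_inputs(0.0, [3.0, 1.0, 4.0, 1.0], 8.0, 8.0, 2.0, 16.0, 2.0)
```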

Closing the track

If you’ve worked through the previous 99 problems, you’ve built every component of a modern decoder-only language model from scratch — twice, in two different Flax APIs. You’ve seen how nnx.split underlies jit and sharding. You’ve seen how nnx.Variable enables eager debugging. You’ve seen the entire bridge layer between Linen and nnx, the Orbax checkpoint format, multi-host data parallelism, RoPE, RMSNorm, SwiGLU, KV-cache, surgery, freezing, warm-starting.

A real production LLM is mostly this code, scaled up. Same DecoderBlock, but with D=4096, H=32, d_ff=14336, L=32. Same RMSNorm. Same SwiGLU with a wider hidden dim. Tied embeddings in some models (Gemma, LLaMA 3.2 1B/3B), a separate output head in others. What changes between this capstone and a frontier model is mostly scale, data, and engineering — not architecture. The architectural primitives are exactly the ones in this file.

Good luck. Write it once. Read it carefully. Watch the logits come out.

Track 12 closes here.
