Mini GPT — Decoder-Only Language Model

Why this matters

GPT (Radford et al., 2018) is the canonical decoder-only Transformer. Strip away the scale and you have a clean recipe:

  1. Token embedding: a (V, D) lookup table.
  2. Position embedding — learned (T, D) added to the tokens.
  3. N stacked decoder blocks — each is causal self-attn + FFN, no cross-attention (decoder-only).
  4. Final LayerNorm — clean up before the head.
  5. Tied output head — multiply by embed.T instead of a fresh Dense(V).

Output: a (T, V) matrix of logits, one score per vocabulary token at each position t. Train with cross-entropy: position t’s logits should peak on token[t+1], so the last position has no target within the sequence.
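The shifted-target objective can be sketched in plain JAX. This is a minimal illustration, not the site's reference code: the helper name next_token_loss is hypothetical, and the logits here are hand-made rather than produced by a model.

```python
import jax
import jax.numpy as jnp

def next_token_loss(logits, ids):
    """Mean cross-entropy between logits[t] and ids[t+1] (hypothetical helper)."""
    pred = logits[:-1]      # (T-1, V): the last position has no target, so drop it
    target = ids[1:]        # (T-1,): targets are the ids shifted left by one
    log_probs = jax.nn.log_softmax(pred, axis=-1)
    picked = jnp.take_along_axis(log_probs, target[:, None], axis=-1)[:, 0]
    return -picked.mean()

ids = jnp.array([0, 3, 7, 1], dtype=jnp.int32)

# Uniform logits: every token is equally likely, so the loss is log(V).
uniform_loss = next_token_loss(jnp.zeros((4, 16)), ids)

# Logits that peak sharply on the next token drive the loss toward zero.
sharp_logits = 50.0 * jax.nn.one_hot(jnp.roll(ids, -1), 16)
sharp_loss = next_token_loss(sharp_logits, ids)
```

With uniform logits the loss equals log(V), which is a handy sanity check for an untrained model.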

Architecture

ids (T,)
→ Embed(V, D)        → tok (T, D)
→ + pos_embed        → x   (T, D)
→ block × N          → x   (T, D)        # causal self-attn + FFN
→ LayerNorm          → x   (T, D)
→ x @ embed.T        → logits (T, V)

Why decoder-only

GPT has no encoder because there is nothing to “translate from.” Every token attends only to itself and earlier positions via the causal mask, and predicts the next token. The same architecture serves next-token prediction at training time and generation at inference time.
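A causal mask is just a lower-triangular boolean matrix applied to the attention scores before the softmax. The sketch below uses a single head and random q and k; the function name causal_attention_weights is made up for illustration.

```python
import jax
import jax.numpy as jnp

def causal_attention_weights(q, k):
    """Single-head attention weights under a lower-triangular (causal) mask."""
    T, d = q.shape
    scores = (q @ k.T) / jnp.sqrt(d)                  # (T, T)
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))     # True where key index <= query index
    scores = jnp.where(mask, scores, -jnp.inf)        # masked entries get zero weight
    return jax.nn.softmax(scores, axis=-1)

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (4, 8))
k = jax.random.normal(jax.random.fold_in(key, 1), (4, 8))
w = causal_attention_weights(q, k)   # strictly upper triangle is zero, rows sum to 1
```

Row 0 can only see position 0, so its weight vector is [1, 0, 0, 0]; each later row spreads its weight over one more position.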

Why tied embeddings

Tying input and output embeddings (Press & Wolf, 2017) is common in language models: instead of a separate (D, V) output projection, reuse the (V, D) token embedding as embed @ x for a single (D,) vector, or x @ embed.T for our unbatched (T, D) case. This saves V·D parameters, which is significant when V is 50k+, and Press & Wolf report that it also lowers perplexity.

embed = nn.Embed(V, D)
tok = embed(ids)                      # (T, D)
...
logits = x @ embed.embedding.T        # (T, V) — tied head

embed.embedding is the (V, D) matrix; transpose to (D, V) for the matmul.

Worked walk-through

With V=16, D=8, H=2, d_ff=16, N=2, T=4:

  1. ids = [0, 3, 7, 1]. tok = embed(ids) → (4, 8).
  2. pos = pos_embed[:T] → (4, 8). x = tok + pos.
  3. Block 1: x = x + MHA(LN(x), causal_mask); x = x + FFN(LN(x)).
  4. Block 2: same.
  5. x = LayerNorm(x).
  6. logits = x @ embed.embedding.T → (4, 16). Flatten → (64,).
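The six steps can be traced by hand in plain JAX to confirm the shapes. This sketch makes several assumptions that the text does not specify: random weights scaled by 0.02, no biases, a LayerNorm without learned scale or bias, and GELU in the FFN. All helper names are illustrative.

```python
import jax
import jax.numpy as jnp

V, D, H, d_ff, N, T = 16, 8, 2, 16, 2, 4

def layer_norm(x, eps=1e-6):
    # Normalise over the feature axis; learned scale and bias are omitted here.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def causal_mha(x, p):
    dh = D // H
    # (T, D) -> (H, T, dh): project, then split the feature axis into heads.
    heads = lambda w: (x @ w).reshape(T, H, dh).transpose(1, 0, 2)
    q, k, v = heads(p["wq"]), heads(p["wk"]), heads(p["wv"])
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(dh)             # (H, T, T)
    scores = jnp.where(jnp.tril(jnp.ones((T, T), dtype=bool)), scores, -jnp.inf)
    out = jax.nn.softmax(scores, axis=-1) @ v                    # (H, T, dh)
    return out.transpose(1, 0, 2).reshape(T, D) @ p["wo"]        # merge heads: (T, D)

def block(x, p):
    x = x + causal_mha(layer_norm(x), p)                          # attention sub-layer
    return x + jax.nn.gelu(layer_norm(x) @ p["w1"]) @ p["w2"]     # FFN sub-layer (GELU assumed)

def init_block(key):
    shapes = {"wq": (D, D), "wk": (D, D), "wv": (D, D), "wo": (D, D),
              "w1": (D, d_ff), "w2": (d_ff, D)}
    keys = jax.random.split(key, len(shapes))
    return {name: 0.02 * jax.random.normal(k, s)
            for k, (name, s) in zip(keys, shapes.items())}

k_embed, k_pos, k_blocks = jax.random.split(jax.random.PRNGKey(0), 3)
embed = 0.02 * jax.random.normal(k_embed, (V, D))
pos_embed = 0.02 * jax.random.normal(k_pos, (T, D))
blocks = [init_block(k) for k in jax.random.split(k_blocks, N)]

ids = jnp.array([0, 3, 7, 1], dtype=jnp.int32)
x = embed[ids] + pos_embed[:T]     # steps 1-2: (4, 8)
for p in blocks:                   # steps 3-4: shape stays (4, 8)
    x = block(x, p)
x = layer_norm(x)                  # step 5: (4, 8)
logits = x @ embed.T               # step 6: (4, 16)
flat = logits.reshape(-1)          # (64,)
```

Every step except the last preserves the (T, D) shape, which is what lets the blocks stack without any reshaping between them.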

Common pitfalls

  • Forgetting the causal mask: the WHOLE POINT of GPT is causal attention. Without it, training “leaks” future tokens into each position’s prediction.
  • Forgetting position embeddings: self-attention has no built-in notion of order (it is permutation-equivariant), so without positions two identical tokens at different positions enter block 1 with identical representations. Add learned positions BEFORE block 1.
  • Not casting token_ids to int: nn.Embed requires int dtype.
  • Re-creating the embed in the head: tied head means USE the same embed.embedding matrix, not a fresh Dense(V).
  • Skipping the final LayerNorm: with Pre-LN blocks the residual stream itself is never normalized, so the final LayerNorm is what rescales it before the head. GPT-2 and GPT-3 both include it.
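The first pitfall is easy to test: change only the last token and check whether earlier positions move. The sketch below uses a deliberately tiny single-head attention with identity projections (an assumption made purely to keep it short); the same check works on a full model.

```python
import jax
import jax.numpy as jnp

def self_attention(x, causal):
    """Single-head self-attention with identity projections (tiny stand-in)."""
    T, d = x.shape
    scores = (x @ x.T) / jnp.sqrt(d)
    if causal:
        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ x

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
x_changed = x.at[-1].set(0.0)      # alter only the last position

# With the mask, positions 0..2 cannot see position 3, so their outputs are unchanged.
causal_same = bool(jnp.allclose(self_attention(x, True)[:-1],
                                self_attention(x_changed, True)[:-1]))

# Without the mask, every position attends to position 3, so earlier outputs shift.
full_same = bool(jnp.allclose(self_attention(x, False)[:-1],
                              self_attention(x_changed, False)[:-1]))
```

If causal_same comes out False on a real model, some future information is leaking into earlier positions.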

Problem

Implement mini_gpt_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):

  1. Cast all configs to int. Cast token_ids to jnp.int32.
  2. Build a MiniGPT Module:
    • embed = nn.Embed(V, D). tok = embed(ids).
    • pos = self.param("pos_embed", nn.initializers.normal(0.02), (T, D)).
    • x = tok + pos.
    • Loop num_layers GPT blocks (each: causal MHA + FFN, Pre-LN).
    • x = nn.LayerNorm()(x).
    • logits = x @ embed.embedding.T.
  3. Init with jax.random.PRNGKey(seed), apply on ids. Return logits.reshape(-1).

Inputs:

  • seed: int.
  • token_ids: 1-D float array (cast inside).
  • vocab_size: V.
  • d_model: D.
  • num_heads: H. D divisible by H.
  • d_ff: FFN hidden.
  • num_layers: N stacked blocks.

Output: 1-D, length T * V.
