Mini GPT — Decoder-Only Language Model
Why this matters
GPT (Radford et al., 2018) is the canonical decoder-only Transformer. Strip away the scale and you have a clean recipe:
- Token embedding: a (V, D) lookup.
- Position embedding: a learned (T, D) matrix added to the token embeddings.
- N stacked decoder blocks: each is causal self-attention + FFN, with no cross-attention (decoder-only).
- Final LayerNorm: normalizes the activations before the head.
- Tied output head: multiply by embed.T instead of a fresh Dense(V).
Output: a (T, V) matrix of logits, giving for each position t a score
for every vocabulary token. Train with cross-entropy: position t’s
logits should peak on token[t+1].
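A minimal sketch of that loss in jax.numpy, assuming one unbatched sequence of integer ids; the helper name next_token_loss is hypothetical. The shift is the detail that matters: the last position has no next token to predict, and the first token is never a target.

```python
import jax
import jax.numpy as jnp


def next_token_loss(logits: jax.Array, ids: jax.Array) -> jax.Array:
    """Mean cross-entropy of predicting ids[t + 1] from logits[t].

    logits: (T, V) float scores; ids: (T,) integer token ids.
    """
    preds = logits[:-1]                      # (T - 1, V): positions 0..T-2
    targets = ids[1:]                        # (T - 1,): the token that follows each
    log_probs = jax.nn.log_softmax(preds, axis=-1)
    # Log-probability each position assigns to its true next token.
    picked = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -picked.mean()


# With uniform (all-zero) logits the loss is log(V), here log(16) ≈ 2.7726.
ids = jnp.array([0, 3, 7, 1], dtype=jnp.int32)
loss = next_token_loss(jnp.zeros((4, 16)), ids)
print(float(loss))
```

Uniform logits give the chance-level loss log(V); a model that puts almost all mass on token[t+1] drives the loss toward zero.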
Architecture
ids (T,)
→ Embed(V, D) → tok (T, D)
→ + pos_embed → x (T, D)
→ block × N → x (T, D) # causal self-attn + FFN
→ LayerNorm → x (T, D)
→ x @ embed.T → logits (T, V)
Why decoder-only
GPT has no encoder: there is nothing to “translate from.” Every token attends only to its own past via the causal mask and predicts the next token. The same architecture serves next-token prediction at training time and autoregressive generation at inference time.
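To make “attends to its own past” concrete, here is a small jax.numpy sketch of a causal softmax. causal_attention_weights is a hypothetical helper, and the scores are all zeros so that only the mask shapes the result.

```python
import jax
import jax.numpy as jnp


def causal_attention_weights(scores: jax.Array) -> jax.Array:
    """Softmax over the key axis of (T, T) scores, hiding future keys."""
    T = scores.shape[-1]
    # allowed[i, j] is True when key j is at or before query i (lower triangle).
    allowed = jnp.tril(jnp.ones((T, T), dtype=bool))
    # Replace future scores by the most negative finite value so their
    # softmax weight underflows to exactly zero.
    masked = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(masked, axis=-1)


# Row t spreads its weight evenly over positions 0..t: row 0 puts all of it
# on itself, row 3 gives 0.25 to each of positions 0..3.
w = causal_attention_weights(jnp.zeros((4, 4)))
print(w.round(3))
```

Everything above the diagonal is exactly zero, so no position can read a later token; this is the same structure the decoder blocks rely on during both training and generation.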
Why tied embeddings
Tying input and output embeddings (Press & Wolf, 2017) is a common
choice in language models, GPT-2 included: instead of a separate (D, V)
output projection, reuse the (V, D) token embedding, computing embed @ x
for a single (D,) vector or x @ embed.T for a (T, D) stack of positions,
as here. This saves V·D parameters, which is significant when V is 50k+,
and Press & Wolf report that it also improves perplexity.
embed = nn.Embed(V, D)
tok = embed(ids) # (T, D)
...
logits = x @ embed.embedding.T # (T, V) — tied head
embed.embedding is the (V, D) matrix; transpose to (D, V) for
the matmul.
Worked walk-through
With V=16, D=8, H=2, d_ff=16, L=2, T=4:
- ids = [0, 3, 7, 1]. tok = embed(ids) → (4, 8).
- pos = pos_embed[:T] → (4, 8). x = tok + pos.
- Block 1: x = x + MHA(LN(x), causal_mask); x = x + FFN(LN(x)).
- Block 2: same.
- x = LayerNorm(x).
- logits = x @ embed.embedding.T → (4, 16). Flatten → (64,).
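The final flatten is row-major, so each position’s V logits stay contiguous: the logit for token v at position t lands at index t * V + v. A small jax.numpy check, with random numbers standing in for real model logits:

```python
import jax
import jax.numpy as jnp

T, V = 4, 16
logits = jax.random.normal(jax.random.PRNGKey(0), (T, V))  # stand-in model output
flat = logits.reshape(-1)                                   # (T * V,) = (64,)

# Row-major flattening: (t, v) maps to flat index t * V + v.
t, v = 2, 5
same = bool(flat[t * V + v] == logits[t, v])                # True

# Undo the flattening to read each position's predicted next token.
next_tokens = flat.reshape(T, V).argmax(axis=-1)            # (T,)
```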
Common pitfalls
- Forgetting the causal mask: the WHOLE POINT of GPT is causal attention. Without it, training “leaks” future tokens into each position’s prediction.
- Forgetting position embeddings: pure self-attention is permutation-equivariant (permuting the inputs only permutes the outputs), so in unmasked attention two identical tokens at different positions get identical representations. Add learned positions BEFORE block 1.
- Not casting token_ids to int: nn.Embed requires an integer dtype.
- Re-creating the embedding in the head: a tied head means USING the same embed.embedding matrix, not a fresh Dense(V).
- Skipping the final LayerNorm: in a Pre-LN stack the residual stream is never normalized after the last block, so the head would see un-normalized activations; GPT-2 and GPT-3 both include it.
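The position-embedding pitfall can be checked numerically. This toy jax.numpy sketch assumes unmasked single-head attention with identity Q/K/V projections, a simplification (real blocks learn these projections); self_attention is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def self_attention(x: jax.Array) -> jax.Array:
    """Toy unmasked single-head attention with Q = K = V = x (no learned weights)."""
    scores = x @ x.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ x


tok = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
x = jnp.stack([tok[0], tok[1], tok[0]])      # tokens A, B, A: no positions yet

out = self_attention(x)
print(jnp.allclose(out[0], out[2]))           # True: both copies of A look the same

# Permuting the input rows just permutes the output rows.
perm = jnp.array([2, 0, 1])
print(jnp.allclose(self_attention(x[perm]), out[perm], atol=1e-5))  # True

# Adding a different vector per position breaks the tie.
pos = jax.random.normal(jax.random.PRNGKey(1), (3, 8))
out_pos = self_attention(x + pos)
print(jnp.allclose(out_pos[0], out_pos[2]))   # False
```

Random vectors are used for pos only to show the effect; in the model they are the learned pos_embed rows.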
Problem
Implement mini_gpt_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
- Cast all configs to int. Cast token_ids to jnp.int32.
- Build a MiniGPT Module:
  - embed = nn.Embed(V, D). tok = embed(ids).
  - pos = self.param("pos_embed", normal(0.02), (T, D)).
  - x = tok + pos.
  - Loop num_layers GPT blocks (each: causal MHA + FFN, Pre-LN).
  - x = nn.LayerNorm()(x).
  - logits = x @ embed.embedding.T.
- Init with jax.random.PRNGKey(seed), apply on ids. Return logits.reshape(-1).
Inputs:
- seed: int.
- token_ids: 1-D float array (cast inside).
- vocab_size: V.
- d_model: D.
- num_heads: H. D must be divisible by H.
- d_ff: FFN hidden size.
- num_layers: N, the number of stacked blocks.
Output: 1-D, length T * V.