Causal LM Forward Pass
Implement the forward pass of a causal language model (LM): the GPT-style architecture that generates text one token at a time.
This is the decoder-only transformer used in GPT-1, GPT-2, and similar models. Unlike BERT, which attends bidirectionally, each token here can attend only to itself and earlier tokens; a causal mask enforces this.
Pipeline
- Token + position embedding:
  x = w_emb[input_ids] + pos_embed → shape (N, T, d_model), where N is the batch size and T the sequence length.
- num_blocks POST-LN transformer blocks (causal, with a lower-triangular mask). Each block applies:
  - x = LN(x + MHA(x)): causal attention residual, then layer norm
  - x = LN(x + FFN(x)): feed-forward residual, then layer norm
- LM head:
  logits = x @ w_head → shape (N, T, vocab_size).
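The pipeline can be sketched end to end in PyTorch, respecting the from-scratch constraints listed further down. This is a minimal sketch, not the reference solution, and it fixes several details the problem leaves open (all assumptions): a num_heads argument that divides d_model evenly, attention scores scaled by 1/sqrt(d_head) as in Vaswani et al. 2017, FFN(x) = gelu(x @ w_mlp1) @ w_mlp2 with no biases, layer norm with no learnable gain or bias, w_emb of shape (vocab_size, d_model), w_head of shape (d_model, vocab_size), and pos_embed with at least T rows.

```python
import math

import torch


def layer_norm(x, eps=1e-5):
    # (x - mean) / sqrt(var + eps) over the last dim; assumption: no learnable gain/bias.
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def softmax(scores):
    # Max-subtraction leaves the result unchanged but prevents exp() overflow.
    e = torch.exp(scores - scores.max(dim=-1, keepdim=True).values)
    return e / e.sum(dim=-1, keepdim=True)


def gelu(x):
    # Tanh approximation of GELU from the problem statement.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))


def causal_mha(x, w_q, w_k, w_v, w_o, num_heads):
    N, T, d_model = x.shape
    d_head = d_model // num_heads  # assumption: d_model % num_heads == 0

    def split_heads(t):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return t.reshape(N, T, num_heads, d_head).transpose(1, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    # Assumption: standard 1/sqrt(d_head) scaling. Shape (N, num_heads, T, T).
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    mask = torch.tril(torch.ones(T, T, device=x.device))
    scores = scores.masked_fill(mask == 0, -1e9)  # block future positions
    out = softmax(scores) @ v  # (N, num_heads, T, d_head)
    out = out.transpose(1, 2).reshape(N, T, d_model)  # merge heads back
    return out @ w_o


def causal_lm_forward(input_ids, w_emb, pos_embed, blocks_weights, w_head, num_heads):
    T = input_ids.shape[1]
    x = w_emb[input_ids] + pos_embed[:T]  # (N, T, d_model)
    for w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 in blocks_weights:  # one (6, d, d) slice per block
        x = layer_norm(x + causal_mha(x, w_q, w_k, w_v, w_o, num_heads))  # POST-LN attention
        # Assumption: FFN = gelu(x @ w_mlp1) @ w_mlp2, no biases.
        x = layer_norm(x + gelu(x @ w_mlp1) @ w_mlp2)  # POST-LN feed-forward
    return x @ w_head  # (N, T, vocab_size)


# Tiny random example; all sizes are arbitrary.
torch.manual_seed(0)
N, T, d_model, vocab_size, num_blocks = 2, 5, 8, 11, 3
input_ids = torch.randint(0, vocab_size, (N, T))
w_emb = torch.randn(vocab_size, d_model)
pos_embed = torch.randn(T, d_model)
blocks_weights = 0.1 * torch.randn(num_blocks, 6, d_model, d_model)
w_head = torch.randn(d_model, vocab_size)
logits = causal_lm_forward(input_ids, w_emb, pos_embed, blocks_weights, w_head, num_heads=2)
print(logits.shape)  # torch.Size([2, 5, 11])
```

Changing only the last input token should leave the logits at every earlier position unchanged, which is a quick way to check that the mask is actually applied.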
Causal Mask
Build a (T, T) lower-triangular matrix of ones (torch.tril), then use
masked_fill(mask == 0, -1e9) to block future positions before softmax.
The token at position i attends only to positions 0..i.
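A small sketch of the mask acting on one head's (T, T) score matrix with T = 4; the random scores are a stand-in for q @ k^T.

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)  # raw scores for one head: rows = queries, columns = keys
mask = torch.tril(torch.ones(T, T))  # row i has ones in columns 0..i, zeros after
masked = scores.masked_fill(mask == 0, -1e9)  # future keys get a huge negative score

# Manual softmax over keys; the masked entries underflow to exactly 0 after exp().
e = torch.exp(masked - masked.max(dim=-1, keepdim=True).values)
weights = e / e.sum(dim=-1, keepdim=True)
print(weights)  # upper triangle is all zeros; each row sums to 1
```

Row 0 can only see key 0, so its weight there is exactly 1. A finite -1e9 is used instead of -inf: with a causal mask the diagonal is never masked, so either works here, but -inf turns any fully masked row into NaN.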
POST-LN Convention
This problem uses the original Transformer (POST-LN) convention from Vaswani et al. 2017, which GPT-1 also follows:
attn_out = MHA(x) # raw attention output
x = LN(x + attn_out) # residual THEN layer norm
ffn_out = FFN(x) # GELU feed-forward
x = LN(x + ffn_out) # residual THEN layer norm
GPT-2 and later models such as GPT-3 and LLaMA normalize before each sublayer instead
(PRE-LN: x = x + sub(LN(x)); LLaMA uses RMSNorm in place of LN). That is a
separate problem.
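A minimal sketch of the difference between the two conventions, with a random linear map standing in for MHA or FFN (the toy sublayer and the off-center input are illustrative choices, not part of the problem). POST-LN normalizes the residual sum, so every position leaves the step with zero mean and unit variance; PRE-LN normalizes only the sublayer input, so the residual stream itself is never normalized.

```python
import torch


def layer_norm(x, eps=1e-5):
    # Manual layer norm over the last dimension, no learnable gain/bias.
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def post_ln(x, sublayer):
    # POST-LN (this problem): residual add first, then normalize.
    return layer_norm(x + sublayer(x))


def pre_ln(x, sublayer):
    # PRE-LN: normalize the sublayer input; the residual path stays un-normalized.
    return x + sublayer(layer_norm(x))


torch.manual_seed(0)
x = 3.0 * torch.randn(2, 4, 8) + 5.0  # deliberately off-center activations
w = torch.randn(8, 8)


def toy_sublayer(h):
    # Hypothetical stand-in for MHA or FFN.
    return h @ w


y_post = post_ln(x, toy_sublayer)
y_pre = pre_ln(x, toy_sublayer)
print(y_post.mean(dim=-1).abs().max())  # close to 0: every position is normalized
print(y_pre.mean(dim=-1).abs().max())   # not normalized in general
```

This is also why PRE-LN models such as GPT-2 add one final layer norm before the LM head, whereas a POST-LN stack already ends on a layer norm.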
Weight Packing
blocks_weights has shape (num_blocks, 6, d_model, d_model).
For block i, blocks_weights[i, 0..5] are:
[w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].
Simplification: the feed-forward hidden width d_ff equals d_model, so all six weight matrices are
(d_model, d_model).
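A short sketch of reading the packed tensor, with random weights and arbitrary sizes standing in for real ones.

```python
import torch

num_blocks, d_model = 3, 8
blocks_weights = torch.randn(num_blocks, 6, d_model, d_model)

for i in range(num_blocks):
    # blocks_weights[i] has shape (6, d_model, d_model); unpacking iterates over dim 0.
    w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = blocks_weights[i]
    # Each name is a view into the packed tensor, not a copy.
    assert w_o.data_ptr() == blocks_weights[i, 3].data_ptr()
    assert w_mlp1.shape == (d_model, d_model)  # d_ff = d_model in this problem
```

The packing only works because of the d_ff = d_model simplification. GPT-1, for example, uses 768-dimensional states with 3072-unit feed-forward layers, so w_mlp1 would be (d_model, d_ff) and w_mlp2 (d_ff, d_model), and they could not share one tensor with the attention weights.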
Implementation Constraints
- No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm: implement everything from scratch.
- Manual layer norm over the last (d_model) dimension:
  (x - mean) / sqrt(var + eps), eps = 1e-5.
- Manual softmax for attention scores.
- Manual GELU (tanh approximation):
  0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
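The three manual primitives fit in a few lines each. This sketch assumes the variance is the biased (population) estimate, which is what nn.LayerNorm uses; the asserts at the end compare against PyTorch's built-ins, which are fine in a test even though the solution itself may not call them.

```python
import math

import torch
import torch.nn.functional as F


def layer_norm(x, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased (population) variance
    return (x - mean) / torch.sqrt(var + eps)


def softmax(scores):
    # Softmax is shift-invariant, so subtracting the row max changes nothing
    # except keeping exp() from overflowing on large scores.
    e = torch.exp(scores - scores.max(dim=-1, keepdim=True).values)
    return e / e.sum(dim=-1, keepdim=True)


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))


# Cross-check against the built-ins (allowed in tests, not in the solution).
torch.manual_seed(0)
x = torch.randn(2, 5, 8)
assert torch.allclose(layer_norm(x), F.layer_norm(x, (8,), eps=1e-5), atol=1e-5)
assert torch.allclose(softmax(x), torch.softmax(x, dim=-1), atol=1e-6)
assert torch.allclose(gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6)
```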
Output
Return logits at every position: shape (N, T, vocab_size).
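At inference, generation is autoregressive: read the last position's logits, choose the
next token, append it to the input, and run the model again. During training, compute
cross-entropy at every position with targets shifted by one (the target at position t is
the input token at position t+1), so the last position has no target.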
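Both uses of the logits can be sketched with random logits standing in for the model output. Two choices here are assumptions rather than requirements of the problem: the loss uses F.cross_entropy (the from-scratch constraint covers the forward pass), and decoding is greedy (argmax); sampling from softmax(logits / temperature) is the common alternative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, T, vocab_size = 2, 6, 11
input_ids = torch.randint(0, vocab_size, (N, T))
logits = torch.randn(N, T, vocab_size)  # stand-in for the forward pass output

# Training: position t predicts token t+1, so drop the last logit and the first target.
pred = logits[:, :-1, :]   # (N, T-1, vocab_size)
target = input_ids[:, 1:]  # (N, T-1)
loss = F.cross_entropy(pred.reshape(-1, vocab_size), target.reshape(-1))

# Inference: only the last position's logits pick the next token (greedy here).
next_token = logits[:, -1, :].argmax(dim=-1)  # (N,)
input_ids = torch.cat([input_ids, next_token[:, None]], dim=1)  # feed back in next step
```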
References
- Radford et al., “Improving Language Understanding by Generative Pre-Training” (GPT-1), OpenAI 2018.
- Radford et al., “Language Models are Unsupervised Multitask Learners” (GPT-2), OpenAI 2019.