Train Causal LM Pretraining Step
Implement one full causal LM pretraining step — the GPT-style next-token-prediction objective that trains modern language models.
This is the training objective behind GPT-1/2/3/4, LLaMA, Mistral, and most modern large language models. Unlike BERT’s masked LM (which only computes loss at masked positions), causal LM computes the loss at every position.
Pipeline
1. Forward pass (same as causal-lm-forward-pass):
    x = w_emb[tokens[:, :T]] + pos_embed    # (N, T, d_model)
    for each block (POST-LN with causal mask):
        attn_out = MHA_causal(x)
        x = LN(x + attn_out)
        ffn_out = GELU(x @ w_mlp1) @ w_mlp2
        x = LN(x + ffn_out)
    logits = x @ w_head                     # (N, T, vocab_size)
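The forward pass above can be sketched in NumPy as follows. This is only an illustration under several assumptions that the problem statement does not confirm: attention uses a single head instead of multi-head, layer norm has no learnable scale or shift, GELU uses the tanh approximation, and every block weight is d_model × d_model. The helper names (`layer_norm`, `gelu`, `causal_attention`, `forward`) are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis; no learnable scale/shift (assumption).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU (assumption).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def causal_attention(x, w_q, w_k, w_v, w_o):
    # Single-head attention (assumption); position t attends only to positions <= t.
    T, d = x.shape[1], x.shape[2]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)       # (N, T, T)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return (probs @ v) @ w_o

def forward(tokens, w_emb, pos_embed, blocks, w_head):
    T = pos_embed.shape[0]
    x = w_emb[tokens[:, :T]] + pos_embed                 # (N, T, d_model)
    for w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 in blocks:    # POST-LN blocks
        x = layer_norm(x + causal_attention(x, w_q, w_k, w_v, w_o))
        x = layer_norm(x + gelu(x @ w_mlp1) @ w_mlp2)
    return x @ w_head                                    # (N, T, vocab_size)
```

A quick sanity check for the causal mask is that changing the token at position t leaves the logits at all earlier positions unchanged.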
2. Loss — next-token CE at every position:
    target at position t = tokens[:, t+1]
    loss = mean CE over all (N, T) positions
    dlogits = (softmax(logits) - one_hot(targets)) / (N*T)
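The loss and its gradient with respect to the logits can be sketched in NumPy as below. The function name `next_token_loss` is hypothetical, and the sketch assumes `tokens` has T + 1 columns so that every one of the T positions has a next token to predict.

```python
import numpy as np

def next_token_loss(logits, tokens):
    # logits: (N, T, V); tokens: (N, T+1). Target at position t is tokens[:, t+1].
    N, T, V = logits.shape
    targets = tokens[:, 1:T + 1]
    # Numerically stable log-softmax over the vocabulary axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    n_idx = np.arange(N)[:, None]
    t_idx = np.arange(T)[None, :]
    loss = -log_probs[n_idx, t_idx, targets].mean()      # mean CE over N*T positions
    # dlogits = (softmax(logits) - one_hot(targets)) / (N*T)
    dlogits = np.exp(log_probs)
    dlogits[n_idx, t_idx, targets] -= 1.0
    dlogits /= N * T
    return loss, dlogits
```

Because each softmax row sums to one and each one-hot row sums to one, every row of `dlogits` sums to zero, which makes a cheap check on a hand-written implementation.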
3. Backward by hand — manual chain rule:
- CE loss → dlogits at every (batch, position)
- LM head: dw_head = x_final.T @ dlogits_all, dx = dlogits_all @ w_head.T
- Each block in reverse (POST-LN, causal mask):
    - FFN backward: dw_mlp2, dw_mlp1, through GELU, through LN
    - Attention backward: dw_o, dw_v, dw_k, dw_q, through causal softmax, through LN
- Embed + pos: dpos_embed = dx.sum(dim=0); dw_emb via index_add_ (each token id accumulates gradient)
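The two ends of the backward pass (LM head and embeddings) can be sketched in NumPy as follows. The function names are hypothetical. The sketch flattens the batch and position axes before the `x_final.T @ dlogits_all` product, and uses `np.add.at` as a NumPy stand-in for the `index_add_` accumulation mentioned above. The per-block backward is not covered here.

```python
import numpy as np

def lm_head_backward(x_final, dlogits, w_head):
    # Flatten (N, T, .) to (N*T, .) so the weight gradient is a single matmul.
    N, T, d = x_final.shape
    x2 = x_final.reshape(N * T, d)
    g2 = dlogits.reshape(N * T, -1)
    dw_head = x2.T @ g2                  # (d_model, vocab_size)
    dx = dlogits @ w_head.T              # (N, T, d_model)
    return dw_head, dx

def embedding_backward(dx, input_tokens, vocab_size):
    # dx: gradient w.r.t. x = w_emb[input_tokens] + pos_embed, shape (N, T, d_model).
    dpos_embed = dx.sum(axis=0)          # positions are shared across the batch
    dw_emb = np.zeros((vocab_size, dx.shape[-1]))
    # np.add.at accumulates over repeated token ids; dw_emb[ids] += dx would not.
    np.add.at(dw_emb, input_tokens, dx)
    return dpos_embed, dw_emb
```

The accumulation matters because the same token id usually appears several times in a batch, and each occurrence contributes its own gradient row.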
4. SGD update on all parameters.
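As a minimal sketch of the update in step 4, assuming plain SGD without momentum or weight decay and a hypothetical learning-rate argument `lr` (the statement does not name one), with parameters and gradients stored in dictionaries keyed by the same names:

```python
import numpy as np

def sgd_update(params, grads, lr):
    # Plain SGD: p <- p - lr * dp, applied in place to every parameter array.
    for name in params:
        params[name] -= lr * grads[name]
    return params
```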
Key difference from MLM
MLM loss: applied only at the masked positions.
Causal LM loss: applied at every position; the targets are the input tokens shifted by one.
This is more data-efficient: a single sequence of length T gives T training signals.
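A small NumPy illustration of the shift, using made-up token ids. The MLM mask shown for contrast is hypothetical and only marks which positions would contribute to an MLM loss.

```python
import numpy as np

tokens = np.array([[5, 7, 2, 9, 4]])     # one sequence of T + 1 = 5 token ids
inputs = tokens[:, :-1]                  # what the model sees: [[5, 7, 2, 9]]
targets = tokens[:, 1:]                  # what it must predict: [[7, 2, 9, 4]]
causal_positions = targets.size          # every one of the T positions has a loss term

# Hypothetical MLM setup: only the masked positions carry a loss term.
mlm_mask = np.array([[False, True, False, False, False]])
mlm_positions = int(mlm_mask.sum())
```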
Output
Returns a single flat tensor of all updated weights in order:
[w_emb_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
Total elements = vocab_size×d_model + T×d_model + num_blocks×6×d_model² + d_model×vocab_size.
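The element count and the flattening order can be checked with a short NumPy sketch. It assumes each block holds six d_model × d_model matrices (w_q, w_k, w_v, w_o, w_mlp1, w_mlp2), which is what the 6×d_model² term suggests. The order of weights inside a block is an assumption, and both function names are hypothetical.

```python
import numpy as np

def expected_num_params(vocab_size, d_model, T, num_blocks):
    # vocab_size*d_model + T*d_model + num_blocks*6*d_model^2 + d_model*vocab_size
    return (vocab_size * d_model + T * d_model
            + num_blocks * 6 * d_model**2 + d_model * vocab_size)

def flatten_params(w_emb, pos_embed, blocks, w_head):
    # Order: [w_emb_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
    parts = [w_emb.ravel(), pos_embed.ravel()]
    for block in blocks:                  # block = tuple of six weight matrices
        parts.extend(w.ravel() for w in block)
    parts.append(w_head.ravel())
    return np.concatenate(parts)
```

For example, with vocab_size = 10, d_model = 4, T = 3 and num_blocks = 2 the total is 40 + 12 + 192 + 40 = 284 elements.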
Implementation constraints
- No loss.backward(): implement the backward pass manually.
- Manual layer norm, GELU, softmax: same helpers as causal-lm-forward-pass.
- Cache all intermediate activations during the forward pass.
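One way to pair each manual helper with a cache for its backward pass is sketched below in NumPy. The helpers in causal-lm-forward-pass are not shown in this statement, so everything here is an assumption: layer norm without learnable scale or shift, the tanh approximation of GELU, and hypothetical function names. Each forward function returns its output together with the values the matching backward function needs.

```python
import numpy as np

def layer_norm_forward(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    x_hat = (x - mu) * inv_std
    return x_hat, (x_hat, inv_std)        # output plus cache for backward

def layer_norm_backward(dy, cache):
    x_hat, inv_std = cache
    mean_dy = dy.mean(axis=-1, keepdims=True)
    mean_dy_xhat = (dy * x_hat).mean(axis=-1, keepdims=True)
    return inv_std * (dy - mean_dy - x_hat * mean_dy_xhat)

K = np.sqrt(2.0 / np.pi)

def gelu_forward(x):
    t = np.tanh(K * (x + 0.044715 * x**3))
    return 0.5 * x * (1.0 + t), (x, t)

def gelu_backward(dy, cache):
    x, t = cache
    dt_dx = (1.0 - t**2) * K * (1.0 + 3 * 0.044715 * x**2)
    return dy * (0.5 * (1.0 + t) + 0.5 * x * dt_dx)

def softmax_backward(dprobs, probs):
    # Row-wise softmax Jacobian product. Causally masked entries have
    # probs == 0, so they receive zero gradient automatically.
    inner = (dprobs * probs).sum(axis=-1, keepdims=True)
    return probs * (dprobs - inner)
```

Comparing each backward function against a finite-difference estimate on a small random input is a practical way to catch sign or axis mistakes before wiring the pieces into the full block.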
References
- Radford et al., “Improving Language Understanding by Generative Pre-Training” (GPT-1), OpenAI 2018.