
Train Causal LM Pretraining Step

Implement one full causal LM pretraining step (forward pass, loss, hand-written backward pass, and parameter update) for the GPT-style next-token-prediction objective that trains modern language models.

This is the training objective behind GPT-1/2/3/4, LLaMA, Mistral, and most modern large language models. Unlike BERT’s masked LM (which only computes loss at masked positions), causal LM computes the loss at every position.

Pipeline

1. Forward pass (same as causal-lm-forward-pass):

x = w_emb[tokens[:, :T]] + pos_embed           # tokens: (N, T+1); x: (N, T, d_model)
for each block (POST-LN with causal mask):
    attn_out = MHA_causal(x)
    x = LN(x + attn_out)
    ffn_out = GELU(x @ w_mlp1) @ w_mlp2
    x = LN(x + ffn_out)
logits = x @ w_head                            # (N, T, vocab_size)
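A minimal NumPy sketch of the forward pass above. It assumes NumPy rather than PyTorch, a params dict with hypothetical keys (w_emb, pos_embed, blocks, w_head, and per-block q, k, v, o, mlp1, mlp2), heads that split d_model evenly, layer norm without learnable gain or bias, and the tanh approximation of GELU. The activation cache that the task requires is omitted here to keep the sketch short.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis; no learnable gain/bias (assumption).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU (assumption; the exact erf form is also common).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))  # subtract the row max for stability
    return e / e.sum(axis=-1, keepdims=True)

def block_forward(x, w, n_heads):
    """One POST-LN block with causal multi-head attention; x: (N, T, d_model)."""
    N, T, d = x.shape
    dh = d // n_heads

    def split_heads(a):  # (N, T, d) -> (N, heads, T, dh)
        return a.reshape(N, T, n_heads, dh).transpose(0, 2, 1, 3)

    # Hypothetical per-block weight keys: "q", "k", "v", "o", "mlp1", "mlp2".
    q, k, v = (split_heads(x @ w[name]) for name in ("q", "k", "v"))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh)          # (N, heads, T, T)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)          # True where key index > query index
    probs = softmax(np.where(future, -np.inf, scores))          # masked entries get probability 0
    attn = (probs @ v).transpose(0, 2, 1, 3).reshape(N, T, d)   # merge heads back to d_model
    x = layer_norm(x + attn @ w["o"])                           # attention sublayer, POST-LN
    return layer_norm(x + gelu(x @ w["mlp1"]) @ w["mlp2"])      # FFN sublayer, POST-LN

def forward(tokens, params, n_heads):
    """tokens: (N, T+1) int ids; returns logits (N, T, vocab_size) for inputs tokens[:, :T]."""
    T = tokens.shape[1] - 1
    x = params["w_emb"][tokens[:, :T]] + params["pos_embed"]   # pos_embed: (T, d_model)
    for w in params["blocks"]:
        x = block_forward(x, w, n_heads)
    return x @ params["w_head"]
```

Because the mask sets every score above the diagonal to -inf before the softmax, the logits at position t depend only on input tokens 0..t.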

2. Loss — next-token CE at every position:

targets = tokens[:, 1:T+1]                     # target at position t is tokens[:, t+1]
loss = mean CE over all (N, T) positions
dlogits = (softmax(logits) - one_hot(targets)) / (N*T)
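The loss and its gradient can be sketched in NumPy as follows, assuming logits of shape (N, T, vocab_size) and an integer tokens array of shape (N, T+1); causal_lm_loss_and_grad is a hypothetical helper name.

```python
import numpy as np

def causal_lm_loss_and_grad(logits, tokens):
    """Mean next-token cross-entropy over all (N, T) positions and its gradient w.r.t. logits.

    logits: (N, T, V) computed from inputs tokens[:, :T]; tokens: (N, T+1) int ids.
    """
    N, T, V = logits.shape
    targets = tokens[:, 1:T + 1]                           # target at position t is tokens[:, t+1]
    shifted = logits - logits.max(axis=-1, keepdims=True)  # stabilize exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    n_idx, t_idx = np.meshgrid(np.arange(N), np.arange(T), indexing="ij")
    loss = -log_probs[n_idx, t_idx, targets].mean()
    dlogits = np.exp(log_probs)                            # softmax(logits)
    dlogits[n_idx, t_idx, targets] -= 1.0                  # minus one_hot(targets)
    dlogits /= N * T                                       # the loss is a mean over N*T terms
    return loss, dlogits
```

Dividing by N*T makes dlogits the gradient of the mean rather than the sum. As a sanity check, all-zero logits give a loss of log(vocab_size).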

3. Backward by hand — manual chain rule:

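  • CE loss → dlogits at every (batch, position)
  • LM head: with x_final flattened to (N*T, d_model) and dlogits_all to (N*T, vocab_size), dw_head = x_final.T @ dlogits_all, dx = dlogits_all @ w_head.T
  • Each block in reverse (POST-LN, causal mask):
    • FFN sublayer: back through its LN first, then dw_mlp2, back through GELU, dw_mlp1; add the residual-path gradient
    • Attention sublayer: back through its LN first, then dw_o, dw_v, back through the causal softmax, dw_q, dw_k; add the residual-path gradient
  • Embedding + positions: dpos_embed = dx.sum(dim=0); dw_emb via index_add_ over the input token ids (repeated ids accumulate their gradients)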
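The attention part of the block backward can be sketched for a single head and a single sequence. This assumes NumPy, leaves out the surrounding LN and residual (handled around this sublayer), and uses hypothetical names attn_forward and attn_backward. Multi-head attention repeats the same steps per head with the scale 1/sqrt(d_head).

```python
import numpy as np

def attn_forward(x, wq, wk, wv, wo):
    """Single-head causal self-attention on one sequence x: (T, d). Returns output and cache."""
    T = x.shape[0]
    q, k, v = x @ wq, x @ wk, x @ wv
    s = q @ k.T / np.sqrt(q.shape[1])
    s = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, s)  # causal mask
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    o = p @ v
    return o @ wo, (x, q, k, v, p, o)

def attn_backward(dout, cache, wq, wk, wv, wo):
    """Manual chain rule; returns (dx, dwq, dwk, dwv, dwo)."""
    x, q, k, v, p, o = cache
    dwo = o.T @ dout
    do = dout @ wo.T
    dv = p.T @ do
    dp = do @ v.T
    # Row-wise softmax backward; masked entries have p == 0, so their gradient is 0 as well.
    ds = p * (dp - (dp * p).sum(axis=-1, keepdims=True))
    ds /= np.sqrt(q.shape[1])                 # undo the 1/sqrt(d_head) scale
    dq, dk = ds @ k, ds.T @ q
    dwq, dwk, dwv = x.T @ dq, x.T @ dk, x.T @ dv
    dx = dq @ wq.T + dk @ wk.T + dv @ wv.T    # x feeds q, k and v
    return dx, dwq, dwk, dwv, dwo
```

No explicit mask is needed in the backward pass: masked probabilities are exactly 0, so the softmax gradient p * (dp - sum(dp * p)) vanishes there automatically.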

4. SGD update on all parameters: w ← w − lr·dw, where lr is the learning rate.
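The last backward step and the update can be sketched in NumPy as below. Here np.add.at stands in for PyTorch's index_add_ (it is an unbuffered scatter-add, so repeated token ids accumulate); embedding_grads and sgd_step are hypothetical names, and lr is the learning rate.

```python
import numpy as np

def embedding_grads(dx, input_ids, vocab_size):
    """dx: (N, T, d) gradient at the embedding output; input_ids: (N, T) = tokens[:, :T]."""
    d = dx.shape[-1]
    dpos_embed = dx.sum(axis=0)              # each position vector is shared across the batch
    dw_emb = np.zeros((vocab_size, d))
    # Unbuffered scatter-add: every occurrence of a token id adds its row of dx.
    np.add.at(dw_emb, input_ids.reshape(-1), dx.reshape(-1, d))
    return dw_emb, dpos_embed

def sgd_step(params, grads, lr):
    # Plain SGD, w <- w - lr * dw, applied to every parameter array.
    return {name: params[name] - lr * grads[name] for name in params}
```

A buffered assignment such as dw_emb[ids] += rows would keep only one contribution per repeated id, which is why a scatter-add is used.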

Key difference from MLM

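MLM computes the loss only at [MASK] positions. Causal LM computes it at every position, using the input shifted left by one as the targets. This is more data-efficient: a sequence with T predicted positions yields T training signals, whereas BERT-style MLM, which masks about 15% of tokens, yields roughly 0.15·T.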

Output

Returns a single flat tensor of all updated weights in order:

[w_emb_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]

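Total elements = vocab_size×d_model + T×d_model + num_blocks×6×d_model² + d_model×vocab_size. The 6×d_model² term counts six d_model×d_model matrices per block (w_q, w_k, w_v, w_o, w_mlp1, w_mlp2), which implies an FFN hidden width of d_model and no bias vectors or learnable LN parameters.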
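A sketch of packing the output in the required order, assuming NumPy, row-major flattening, and a hypothetical per-block order w_q, w_k, w_v, w_o, w_mlp1, w_mlp2. The source does not fix the order inside blocks_weights_flat, so check the expected layout before relying on it.

```python
import numpy as np

# Hypothetical order of the six matrices inside each block (not specified by the task).
BLOCK_ORDER = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")

def pack_weights(w_emb, pos_embed, blocks, w_head):
    """Concatenate [w_emb, pos_embed, blocks..., w_head] into one flat (row-major) vector."""
    parts = [w_emb.ravel(), pos_embed.ravel()]
    for blk in blocks:
        parts.extend(blk[name].ravel() for name in BLOCK_ORDER)
    parts.append(w_head.ravel())
    return np.concatenate(parts)

def expected_size(vocab_size, T, d_model, num_blocks):
    # Matches the element-count formula stated above.
    return vocab_size * d_model + T * d_model + num_blocks * 6 * d_model**2 + d_model * vocab_size
```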

Implementation constraints

  • No loss.backward() — implement the backward pass manually.
  • Manual layer norm, GELU, softmax — same helpers as causal-lm-forward-pass.
  • Cache all intermediate activations during the forward pass.
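The manual layer norm and GELU helpers also need backward rules. A NumPy sketch, assuming layer norm without learnable gain or bias and the tanh approximation of GELU (both are assumptions; the source names the helpers but not their exact form). For layer norm with y = (x − mean)/σ, the input gradient is (dy − mean(dy) − y·mean(dy·y)) / σ.

```python
import numpy as np

EPS = 1e-5
C = np.sqrt(2.0 / np.pi)

def layer_norm_fwd(x):
    mu = x.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x.var(axis=-1, keepdims=True) + EPS)
    y = (x - mu) / sigma
    return y, (y, sigma)                      # cache what the backward pass needs

def layer_norm_bwd(dy, cache):
    y, sigma = cache
    mean_dy = dy.mean(axis=-1, keepdims=True)
    mean_dy_y = (dy * y).mean(axis=-1, keepdims=True)
    return (dy - mean_dy - y * mean_dy_y) / sigma

def gelu_fwd(x):
    return 0.5 * x * (1.0 + np.tanh(C * (x + 0.044715 * x**3)))

def gelu_bwd(dy, x):
    # Derivative of 0.5 * x * (1 + tanh(u)) with u = C * (x + 0.044715 * x^3).
    t = np.tanh(C * (x + 0.044715 * x**3))
    dgelu = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * C * (1.0 + 3 * 0.044715 * x**2)
    return dy * dgelu
```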

References

  • Radford et al., “Improving Language Understanding by Generative Pre-Training” (GPT-1), OpenAI 2018.
