Train MLM Pretraining Step
Implement one full BERT-style masked language model (MLM) pretraining step — masking, forward pass, cross-entropy loss at masked positions, manual backpropagation, and an SGD parameter update.
This problem composes Tasks 6 + 7 (masking + bidirectional forward) and adds the complete backward pass through the LM head, transformer blocks, embeddings, and position embeddings.
Pipeline
1. BERT Masking (Task 6 algorithm):
gen = torch.Generator().manual_seed(seed)
select_probs = torch.rand(N, T, generator=gen)
corrupt_probs = torch.rand(N, T, generator=gen)
random_tokens = torch.randint(0, vocab_size, (N, T), generator=gen)
A position is selected if select_prob < 0.15. For selected positions:
- corrupt_prob < 0.8 → replace with mask_token_id
- 0.8 ≤ corrupt_prob < 0.9 → replace with the random token
- corrupt_prob ≥ 0.9 → keep the original token (still marked as selected)
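The masking rule above can be sketched as follows. This is a minimal illustration, not the reference solution: the function name bert_mask and its argument names are hypothetical, and returning the mask indicator as a float tensor is an assumption based on the later mask_indicator > 0.5 test.

```python
import torch

def bert_mask(tokens, vocab_size, mask_token_id, seed):
    """Return (corrupted_ids, mask_indicator) for a (N, T) tensor of token ids."""
    N, T = tokens.shape
    gen = torch.Generator().manual_seed(seed)
    # The three draws are made in this order so results are reproducible.
    select_probs = torch.rand(N, T, generator=gen)
    corrupt_probs = torch.rand(N, T, generator=gen)
    random_tokens = torch.randint(0, vocab_size, (N, T), generator=gen)

    selected = select_probs < 0.15
    use_mask = selected & (corrupt_probs < 0.8)
    use_random = selected & (corrupt_probs >= 0.8) & (corrupt_probs < 0.9)
    # Selected positions with corrupt_probs >= 0.9 keep their original token.

    corrupted = tokens.clone()
    corrupted[use_mask] = mask_token_id
    corrupted[use_random] = random_tokens[use_random]
    # Assumption: the indicator is stored as floats (1.0 = selected).
    return corrupted, selected.float()
```

Unselected positions are never modified, and calling the function twice with the same seed gives identical outputs, which is what makes the step testable.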
2. Forward Pass — bidirectional (NO causal mask):
x = w_emb[corrupted_ids] + pos_embed # embed
for each block (pre-LN): # repeated num_blocks times
x = x + MHA(LN(x)) @ w_o
x = x + GELU(LN(x) @ w_mlp1) @ w_mlp2
logits_all = x @ w_head # LM head
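A possible forward pass matching the pseudocode above is sketched below. Several details are assumptions rather than facts from the problem statement: each block is taken to hold six d_model × d_model matrices (named here w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, consistent with the stated 6×d_model² parameters per block), layer norm has no learned scale or shift, GELU is the exact erf form, and num_heads is a hypothetical argument.

```python
import math
import torch

def layer_norm(x, eps=1e-5):
    # Assumption: no learned scale/shift; eps value is illustrative.
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    return (x - mu) / torch.sqrt(var + eps)

def gelu(x):
    # Assumption: exact GELU; the task may use the tanh approximation instead.
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def softmax(x):
    # Manual, numerically stable softmax over the last dimension.
    e = torch.exp(x - x.max(-1, keepdim=True).values)
    return e / e.sum(-1, keepdim=True)

def block_forward(x, w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads):
    N, T, D = x.shape
    hd = D // num_heads
    h = layer_norm(x)
    # Split into heads: (N, T, D) -> (N, num_heads, T, hd)
    q = (h @ w_q).reshape(N, T, num_heads, hd).transpose(1, 2)
    k = (h @ w_k).reshape(N, T, num_heads, hd).transpose(1, 2)
    v = (h @ w_v).reshape(N, T, num_heads, hd).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(hd)  # (N, num_heads, T, T)
    attn = softmax(scores)  # no causal mask: every position sees every other
    ctx = (attn @ v).transpose(1, 2).reshape(N, T, D)
    x = x + ctx @ w_o
    x = x + gelu(layer_norm(x) @ w_mlp1) @ w_mlp2
    return x

def mlm_forward(corrupted_ids, w_emb, pos_embed, blocks, w_head, num_heads):
    # blocks: list of 6-tuples of weight matrices, one tuple per block.
    x = w_emb[corrupted_ids] + pos_embed
    for blk in blocks:
        x = block_forward(x, *blk, num_heads=num_heads)
    return x @ w_head  # (N, T, vocab_size)
```

A real solution would also store the intermediate tensors (h, q, k, v, attn, and the MLP pre-activations) for the manual backward pass; the sketch omits that to stay short.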
3. Loss — CE only at masked positions:
logits_masked = logits_all[mask_indicator > 0.5] # (M, vocab_size)
targets = original_tokens[mask_indicator > 0.5] # (M,)
loss = -mean(log_softmax(logits_masked)[range(M), targets])
The masked-only loss is the key difference from causal LM training: the model is only penalized on positions it was told to predict.
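As an illustration, the masked-only loss can be written with a manual log-softmax as below. The function name masked_ce_loss is hypothetical, and the sketch assumes at least one position is masked (M > 0); how the case M = 0 should be handled is not stated in the problem.

```python
import torch

def masked_ce_loss(logits_all, original_tokens, mask_indicator):
    """Mean cross-entropy over masked positions only."""
    sel = mask_indicator > 0.5
    logits_masked = logits_all[sel]   # (M, vocab_size)
    targets = original_tokens[sel]    # (M,)
    M = targets.shape[0]
    # Manual log-softmax; subtracting the row max keeps exp() from overflowing.
    shifted = logits_masked - logits_masked.max(dim=-1, keepdim=True).values
    log_probs = shifted - torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True))
    return -log_probs[torch.arange(M), targets].mean()
```

Logits at unmasked positions never enter the result, so their gradient is exactly zero.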
4. Backward by hand — chain rule through:
- CE loss → dlogits_masked = (softmax(logits_masked) - one_hot(targets)) / M
- Scatter dlogits_masked back into dlogits_all (zero at unmasked positions)
- LM head: dw_head = x_final.T @ dlogits_all (with batch and sequence dimensions flattened), dx = dlogits_all @ w_head.T
- Each block in reverse order: FFN backward, then attention backward (same as Task 4 ViT backward)
- Embedding: dw_emb accumulates via index_add_ (sparse update: each token id receives gradient only from the positions where it was looked up); dpos_embed = dx.sum(0)
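The loss, LM-head, and embedding steps of this chain can be sketched as follows. The per-block backward is left out, and the function names are hypothetical. The sketch assumes logits_all has shape (N, T, vocab_size), x_final and dx have shape (N, T, d_model), and all tensors live on the CPU.

```python
import torch

def loss_and_head_backward(logits_all, x_final, w_head, original_tokens, mask_indicator):
    N, T, V = logits_all.shape
    sel = mask_indicator > 0.5
    M = int(sel.sum())
    # Manual softmax over the masked rows only.
    z = logits_all[sel]
    z = z - z.max(-1, keepdim=True).values
    probs = torch.exp(z) / torch.exp(z).sum(-1, keepdim=True)
    # dL/dlogits = (softmax - one_hot) / M
    dlogits_masked = probs.clone()
    dlogits_masked[torch.arange(M), original_tokens[sel]] -= 1.0
    dlogits_masked /= M
    # Scatter back; unmasked positions keep a zero gradient.
    dlogits_all = torch.zeros_like(logits_all)
    dlogits_all[sel] = dlogits_masked
    # Flatten (N, T) so the weight gradient is a plain matrix product.
    dw_head = x_final.reshape(N * T, -1).T @ dlogits_all.reshape(N * T, V)
    dx = dlogits_all @ w_head.T  # (N, T, d_model), fed into the last block
    return dw_head, dx

def embedding_backward(dx, corrupted_ids, vocab_size):
    # dx here is the gradient that reaches the embedding output,
    # i.e. after it has passed backward through every block.
    N, T, D = dx.shape
    dw_emb = torch.zeros(vocab_size, D, dtype=dx.dtype)
    # Rows are accumulated per token id; repeated ids add up.
    dw_emb.index_add_(0, corrupted_ids.reshape(-1), dx.reshape(N * T, D))
    dpos_embed = dx.sum(0)  # position embeddings are shared across the batch
    return dw_emb, dpos_embed
```

Note that the embedding gradient is indexed by the corrupted ids, since those are the ids actually looked up in the forward pass, while the loss targets come from the original tokens.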
5. SGD update on all parameters.
Output
Returns a single flat tensor of all updated weights in order:
[w_emb_flat, pos_embed_flat, blocks_weights_flat, w_head_flat]
Total elements = vocab_size×d_model + T×d_model + num_blocks×6×d_model² + d_model×vocab_size.
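The update and output format might look like the sketch below. The helper names and the lr argument are hypothetical, and the order of the six matrices inside each block is an assumption, since the problem statement only fixes the top-level order.

```python
import torch

def expected_num_params(vocab_size, d_model, T, num_blocks):
    # Mirrors the element-count formula stated above.
    return (vocab_size * d_model + T * d_model
            + num_blocks * 6 * d_model ** 2 + d_model * vocab_size)

def sgd_and_flatten(params, grads, lr):
    """params and grads are lists ordered [w_emb, pos_embed, *block weights, w_head]."""
    # Plain SGD: p <- p - lr * g, with no momentum or weight decay.
    updated = [p - lr * g for p, g in zip(params, grads)]
    return torch.cat([u.reshape(-1) for u in updated])
```

For example, with vocab_size = 100, d_model = 16, T = 8, and num_blocks = 2, the formula gives 1600 + 128 + 3072 + 1600 = 6400 elements, which is a quick sanity check on the length of the returned tensor.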
Implementation constraints
- No loss.backward(): implement the backward pass manually.
- No torch.distributions.Categorical: use uniform random as in Task 6.
- Manual layer norm, GELU, softmax: same helpers as Task 7.
- Cache all intermediate activations during the forward pass.
References
- Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019.