MLM Masking Strategy
Implement BERT’s masked language model (MLM) masking strategy from Devlin et al. 2018.
Before pre-training, BERT corrupts ~15% of input tokens using an 80/10/10 split:
- 80% of selected positions → replaced with the [MASK] token id.
- 10% of selected positions → replaced with a random vocabulary token (forces the model to be robust to noise at every position).
- 10% of selected positions → kept as the original token (forces representations to remain useful even at unmasked positions).
Only selected positions contribute to the pre-training loss. The output includes both the corrupted token ids and a binary indicator mask so the training loop can compute loss only on the selected positions.
The 15% rate
The 15% target rate balances giving the model enough training signal against corrupting too much of the input context. If the rate is too low, the model rarely gets a training signal; if it is too high, the corrupted context becomes uninformative.
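To make the rates concrete, the short sketch below computes the expected number of positions in each category for a given sequence length. The helper name `expected_counts` is hypothetical and only illustrates the arithmetic; actual counts vary from draw to draw because each position is sampled independently.

```python
def expected_counts(seq_len, mask_prob=0.15):
    """Expected number of positions per category (hypothetical helper)."""
    selected = seq_len * mask_prob
    return {
        "selected": selected,        # positions that contribute to the loss
        "mask": 0.8 * selected,      # replaced with the [MASK] token id
        "random": 0.1 * selected,    # replaced with a random vocabulary token
        "kept": 0.1 * selected,      # left unchanged but still selected
    }


counts = expected_counts(512)
print(counts)
```

For a 512-token sequence this gives about 77 selected positions on average: roughly 61 replaced with [MASK], and about 8 each for the random-token and kept-original cases.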
Implementation
Use a single seeded torch.Generator for all three random draws in order:
gen = torch.Generator().manual_seed(seed)
select_probs = torch.rand(N, T, generator=gen) # which positions
corrupt_probs = torch.rand(N, T, generator=gen) # 80/10/10 split
random_tokens = torch.randint(0, vocab_size, (N, T), generator=gen)
- Position selected if select_prob < mask_prob.
- For selected positions:
  - corrupt_prob < 0.8 → replace with mask_token_id
  - 0.8 ≤ corrupt_prob < 0.9 → replace with random_token
  - corrupt_prob ≥ 0.9 → keep original (but still mark as selected in the indicator)
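The selection and 80/10/10 rules above can be sketched as follows. This is a minimal illustration, not the reference solution: the function name `corrupt_tokens`, its signature, and the default `mask_prob=0.15` are assumptions, and it returns the corrupted ids and the selection mask as two separate tensors for clarity. Inputs are assumed to be integer token ids of shape (N, T) on the CPU.

```python
import torch


def corrupt_tokens(token_ids, vocab_size, mask_token_id, mask_prob=0.15, seed=0):
    """Apply the 80/10/10 corruption; returns (corrupted ids, bool selection mask).

    Hypothetical helper: name and signature are illustrative assumptions.
    """
    N, T = token_ids.shape
    gen = torch.Generator().manual_seed(seed)

    # Draw order matters for reproducibility: selection, split, random tokens.
    select_probs = torch.rand(N, T, generator=gen)
    corrupt_probs = torch.rand(N, T, generator=gen)
    random_tokens = torch.randint(0, vocab_size, (N, T), generator=gen)

    selected = select_probs < mask_prob
    use_mask = selected & (corrupt_probs < 0.8)
    use_random = selected & (corrupt_probs >= 0.8) & (corrupt_probs < 0.9)
    # Remaining selected positions (corrupt_probs >= 0.9) keep the original token.

    corrupted = torch.where(
        use_mask, torch.full_like(token_ids, mask_token_id), token_ids
    )
    corrupted = torch.where(
        use_random, random_tokens.to(token_ids.dtype), corrupted
    )
    return corrupted, selected


ids = torch.arange(200).reshape(4, 50)
corrupted, selected = corrupt_tokens(ids, vocab_size=1000, mask_token_id=3, seed=0)
# Unselected positions are never modified.
assert torch.equal(corrupted[~selected], ids[~selected])
```

Using `torch.where` instead of in-place indexed assignment leaves the input tensor untouched and avoids dtype mismatches between the input ids and the sampled random tokens.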
Output shape
Return a tensor of shape (N, T, 2):
- Channel 0: corrupted token ids (cast to float).
- Channel 1: 1.0 if the position was selected for loss, 0.0 otherwise.
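As a rough sketch of how the two channels might be packed and later consumed by a training loop, consider the example below. The helper names `pack_output` and `masked_lm_loss` are hypothetical, and the example assumes the original (uncorrupted) ids are still available as loss targets.

```python
import torch
import torch.nn.functional as F


def pack_output(corrupted_ids, selected):
    """Stack corrupted ids and the indicator into one (N, T, 2) float tensor."""
    return torch.stack([corrupted_ids.float(), selected.float()], dim=-1)


def masked_lm_loss(logits, original_ids, packed):
    """Cross-entropy over selected positions only (hypothetical helper).

    logits: (N, T, vocab_size); original_ids: (N, T) long; packed: (N, T, 2).
    """
    selected = packed[..., 1].bool()
    return F.cross_entropy(logits[selected], original_ids[selected])


original = torch.tensor([[5, 6, 7]])
corrupted = torch.tensor([[5, 103, 7]])           # position 1 was replaced
selected = torch.tensor([[False, True, False]])

packed = pack_output(corrupted, selected)
print(packed.shape)                               # torch.Size([1, 3, 2])

model_input = packed[..., 0].long()               # recover integer token ids
logits = torch.zeros(1, 3, 10)                    # stand-in for model output
loss = masked_lm_loss(logits, original, packed)
```

Casting ids to float is lossless here because float32 represents every integer up to 2^24 exactly, which covers typical vocabulary sizes. If no position is selected, the loss above is computed over an empty set and evaluates to NaN, so a real training loop would need to guard against that case.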
Reference
Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019.