
MLM Masking Strategy

Implement BERT’s masked language model (MLM) masking strategy from Devlin et al. 2018.

When preparing pre-training inputs, BERT selects ~15% of token positions as prediction targets and treats them with an 80/10/10 split:

  • 80% of selected positions → replaced with [MASK] token id.
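  • 10% of selected positions → replaced with a random vocabulary token (because the model cannot tell which inputs were swapped, it must keep a useful contextual representation of every input token).
  • 10% of selected positions → kept as the original token (biases the representation toward the actually observed token; together with the random replacements, this reduces the mismatch with fine-tuning, where [MASK] never appears).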
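The split is applied on top of the selection rate, so the share of *all* tokens in each category is a product of the two. A small worked calculation, assuming the default rates above (the helper name expected_shares is hypothetical):

```python
def expected_shares(mask_prob: float = 0.15, p_mask: float = 0.8, p_random: float = 0.1) -> dict:
    """Expected fraction of ALL tokens in each MLM category."""
    p_keep = 1.0 - p_mask - p_random          # remaining share of selected positions
    return {
        "replaced_with_mask": mask_prob * p_mask,
        "replaced_with_random": mask_prob * p_random,
        "selected_but_kept": mask_prob * p_keep,
        "not_selected": 1.0 - mask_prob,
    }

print(expected_shares())
```

With the defaults this gives roughly 12% [MASK], 1.5% random, 1.5% kept-but-selected, and 85% untouched, so only about 12% of inputs are literally [MASK] while 3% of tokens are prediction targets that look like ordinary input.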

Only selected positions contribute to the pre-training loss, and the target at each one is the original (uncorrupted) token. The output therefore includes both the corrupted token ids and a binary indicator mask, so the training loop can compute the loss on the selected positions only.
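A minimal sketch of how a training loop might consume the indicator, assuming a hypothetical model that produces logits of shape (N, T, V) and that the original ids are kept as targets. mlm_loss is a hypothetical helper, not part of the required output:

```python
import torch
import torch.nn.functional as F

def mlm_loss(logits: torch.Tensor, original_ids: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy over selected positions only.

    logits:       (N, T, V) scores over the vocabulary
    original_ids: (N, T) uncorrupted token ids (the prediction targets)
    loss_mask:    (N, T) 1.0 where the position was selected, else 0.0
    """
    # cross_entropy expects the class dimension second, so (N, T, V) -> (N, V, T).
    per_token = F.cross_entropy(logits.transpose(1, 2), original_ids, reduction="none")  # (N, T)
    mask = loss_mask.to(per_token.dtype)
    # Average over selected positions; the clamp avoids 0/0 when nothing was selected.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```

Unselected positions are multiplied by zero, so their predictions never affect the gradient.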

The 15% rate

The 15% rate trades training signal against context. Too low, and each sequence yields few prediction targets, so learning is slow; too high, and so much of the input is corrupted that the remaining context no longer constrains the predictions well.
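To put numbers on the "too low" side: under the per-position draw described in Implementation (each position selected independently with probability mask_prob), the number of targets per sequence is binomial. A small sketch, with selection_stats as a hypothetical helper:

```python
import math

def selection_stats(seq_len: int, mask_prob: float = 0.15):
    """Mean and standard deviation of selected positions per sequence,
    assuming independent per-position selection with probability mask_prob."""
    mean = seq_len * mask_prob
    std = math.sqrt(seq_len * mask_prob * (1.0 - mask_prob))
    return mean, std

for T in (16, 128, 512):
    mean, std = selection_stats(T)
    print(f"T={T}: {mean:.1f} +/- {std:.1f} targets per sequence")
```

A 512-token sequence yields about 77 targets on average, while a 16-token sequence yields only a couple, so the rate matters most for short inputs.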

Implementation

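Use a single seeded torch.Generator and make the three random draws in exactly this order, since reordering them changes the output for a given seed:

```python
import torch

# N, T: batch size and sequence length; vocab_size and seed are inputs.
gen = torch.Generator().manual_seed(seed)
select_probs  = torch.rand(N, T, generator=gen)                      # which positions are selected
corrupt_probs = torch.rand(N, T, generator=gen)                      # 80/10/10 split
random_tokens = torch.randint(0, vocab_size, (N, T), generator=gen)  # candidate random replacements
```
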
  • A position is selected if select_prob < mask_prob (the selection rate, 0.15 in BERT).
  • For selected positions:
    • corrupt_prob < 0.8 → replace with mask_token_id
    • 0.8 ≤ corrupt_prob < 0.9 → replace with random_token
    • corrupt_prob ≥ 0.9 → keep original (but still mark as selected in the indicator)
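Putting the draws and rules together, a minimal end-to-end sketch. The function name mlm_mask and its signature are assumptions (the problem statement does not fix them); it assumes token_ids is an integer CPU tensor of shape (N, T):

```python
import torch

def mlm_mask(
    token_ids: torch.Tensor,   # (N, T) integer token ids, on CPU (hypothetical signature)
    mask_token_id: int,
    vocab_size: int,
    mask_prob: float = 0.15,
    seed: int = 0,
) -> torch.Tensor:
    """Apply BERT-style 80/10/10 MLM corruption; returns an (N, T, 2) float tensor."""
    token_ids = token_ids.long()  # match randint's int64 dtype for torch.where
    N, T = token_ids.shape

    # One generator, three draws, in the required order.
    gen = torch.Generator().manual_seed(seed)
    select_probs = torch.rand(N, T, generator=gen)
    corrupt_probs = torch.rand(N, T, generator=gen)
    random_tokens = torch.randint(0, vocab_size, (N, T), generator=gen)

    selected = select_probs < mask_prob
    to_mask = selected & (corrupt_probs < 0.8)
    to_random = selected & (corrupt_probs >= 0.8) & (corrupt_probs < 0.9)
    # Selected positions with corrupt_prob >= 0.9 keep their original id.

    corrupted = torch.where(to_mask, torch.full_like(token_ids, mask_token_id), token_ids)
    corrupted = torch.where(to_random, random_tokens, corrupted)

    # Channel 0: corrupted ids as float; channel 1: loss indicator.
    return torch.stack([corrupted.float(), selected.float()], dim=-1)
```

Using torch.where instead of in-place index assignment keeps the original tensor untouched and avoids dtype mismatches between token_ids and random_tokens.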

Output shape

Return a tensor of shape (N, T, 2):

  • Channel 0: corrupted token ids (cast to float).
  • Channel 1: 1.0 if the position was selected for loss, 0.0 otherwise.
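On the consuming side, a training loop would typically split the two channels back apart. A sketch, with unpack_mlm_output as a hypothetical helper name. Casting ids to float is lossless here because float32 represents every integer up to 2^24 (16,777,216) exactly, well above typical vocabulary sizes:

```python
import torch

def unpack_mlm_output(out: torch.Tensor):
    """Split an (N, T, 2) MLM output into int64 ids and a boolean loss mask."""
    input_ids = out[..., 0].round().long()   # corrupted ids back to embedding indices
    loss_mask = out[..., 1] > 0.5            # True where the position was selected
    return input_ids, loss_mask
```

input_ids feeds the embedding layer, and loss_mask (or its float form) weights the per-token loss.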

Reference

Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019.
