MLM Forward Pass

Implement the forward pass of a masked language model (MLM): a bidirectional transformer encoder with a token-embedding lookup at the input and an LM head at the output. Return logits only at the masked positions.

This is the forward half of BERT’s pre-training step: given a batch of (possibly corrupted) token ids and a binary mask marking the positions selected for the loss, produce logits over the vocabulary for each masked token.

Pipeline

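  1. Token embedding: x = w_emb[input_ids], where input_ids is (N, T) and w_emb is (vocab_size, d_model) → x has shape (N, T, d_model).
  2. Add position embeddings: x = x + pos_embed, where pos_embed is (T, d_model) and is broadcast over the batch dimension.
  3. num_blocks pre-LN transformer blocks, all bidirectional (NO causal mask). Each block applies x = x + MHA(LN(x)), then x = x + FFN(LN(x)), with FFN(h) = GELU(h @ w_mlp1) @ w_mlp2.
  4. LM head: logits_all = x @ w_head, where w_head is (d_model, vocab_size) → shape (N, T, vocab_size).
  5. Gather at masked positions: return logits_all[mask_indicator > 0.5], where mask_indicator is (N, T) → shape (M, vocab_size), with M the total number of masked positions in the batch. Boolean indexing selects rows in row-major (sequence-by-sequence) order.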
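The five steps can be sketched end to end in NumPy, whose indexing matches the PyTorch-style notation above. This is a minimal sketch under stated assumptions, not a reference solution: n_heads is a hypothetical parameter (the spec does not fix the head count), attention scores use the standard 1/sqrt(d_head) scaling, pos_embed is assumed to be (T, d_model), layer norm uses the biased variance, and the FFN is taken to be GELU(h @ w_mlp1) @ w_mlp2 with no biases.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over d_model; no learnable scale/shift, biased variance (assumed).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(s):
    # Subtract the row max for numerical stability.
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def mha(x, w_q, w_k, w_v, w_o, n_heads):
    n, t, d = x.shape
    dh = d // n_heads

    def split(h):  # (N, T, d) -> (N, H, T, dh)
        return h.reshape(n, t, n_heads, dh).transpose(0, 2, 1, 3)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    attn = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh))  # no causal mask
    return (attn @ v).transpose(0, 2, 1, 3).reshape(n, t, d) @ w_o


def mlm_forward(input_ids, mask_indicator, w_emb, pos_embed, blocks_weights, w_head, n_heads):
    x = w_emb[input_ids]                       # 1. (N, T, d_model)
    x = x + pos_embed                          # 2. (T, d_model) broadcast over N
    for w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 in blocks_weights:  # 3. pre-LN blocks
        x = x + mha(layer_norm(x), w_q, w_k, w_v, w_o, n_heads)
        x = x + gelu(layer_norm(x) @ w_mlp1) @ w_mlp2
    logits_all = x @ w_head                    # 4. (N, T, vocab_size)
    return logits_all[mask_indicator > 0.5]    # 5. (M, vocab_size)


# Usage with small, made-up sizes.
rng = np.random.default_rng(0)
N, T, d_model, vocab_size, num_blocks, n_heads = 2, 6, 16, 50, 2, 4
input_ids = rng.integers(0, vocab_size, size=(N, T))
mask_indicator = np.zeros((N, T))
mask_indicator[0, 1] = mask_indicator[1, 3] = mask_indicator[1, 4] = 1.0
logits = mlm_forward(
    input_ids,
    mask_indicator,
    w_emb=rng.normal(size=(vocab_size, d_model)),
    pos_embed=rng.normal(size=(T, d_model)),
    blocks_weights=0.1 * rng.normal(size=(num_blocks, 6, d_model, d_model)),
    w_head=rng.normal(size=(d_model, vocab_size)),
    n_heads=n_heads,
)
print(logits.shape)  # (3, 50)
```

Three positions are masked, so M = 3 and the result is (3, vocab_size). Iterating over blocks_weights yields one (6, d_model, d_model) slice per block, which unpacks directly into the six matrices.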

Bidirectional attention

There is no causal mask: every token can attend to every other token in the sequence, including tokens to its right. This is the key difference from a decoder, and it lets the model use the full left and right context when predicting a masked token.
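To make the difference concrete, here is a NumPy sketch of multi-head self-attention with no mask, plus an optional causal flag used only for contrast. Assumptions: n_heads is a hypothetical parameter (the spec does not fix it), scores are scaled by 1/sqrt(d_head) as in standard scaled dot-product attention, and the weight scale is arbitrary. Perturbing only the last token changes the output at position 0 in the bidirectional case but not in the causal one.

```python
import numpy as np


def softmax(scores):
    # Row-wise softmax over the last (key) axis, max-shifted for stability.
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention(x, w_q, w_k, w_v, w_o, n_heads, causal=False):
    """x: (N, T, d_model). n_heads is an assumed parameter that must divide d_model."""
    n, t, d_model = x.shape
    d_head = d_model // n_heads

    def split(h):  # (N, T, d_model) -> (N, n_heads, T, d_head)
        return h.reshape(n, t, n_heads, d_head).transpose(0, 2, 1, 3)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # (N, H, T, T)
    if causal:
        # Decoder-style mask, shown only for contrast: hide keys j > i.
        future = np.triu(np.ones((t, t), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    attn = softmax(scores)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(n, t, d_model)
    return out @ w_o


rng = np.random.default_rng(0)
d_model, n_heads = 8, 2
# A small weight scale keeps the softmax from saturating.
w_q, w_k, w_v, w_o = (rng.normal(scale=0.3, size=(d_model, d_model)) for _ in range(4))
x = rng.normal(size=(1, 5, d_model))
x_changed = x.copy()
x_changed[0, -1] += 1.0  # perturb only the last token

for causal in (False, True):
    y = attention(x, w_q, w_k, w_v, w_o, n_heads, causal)
    y_changed = attention(x_changed, w_q, w_k, w_v, w_o, n_heads, causal)
    # Max change in the output at position 0.
    print("causal" if causal else "bidirectional", np.abs(y[0, 0] - y_changed[0, 0]).max())
```

For the MLM encoder only the causal=False path is used; the flag exists here just to show what the decoder mask would remove.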

Weight packing

blocks_weights has shape (num_blocks, 6, d_model, d_model). For block i, blocks_weights[i, 0..5] are: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].

Simplification: d_ff = d_model, so all six weight matrices are (d_model, d_model).
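A short NumPy sketch of the packing convention: unpacking one block and applying its FFN. The sizes are made up, unpack_block and ffn are hypothetical helper names, and the FFN form GELU(h @ w_mlp1) @ w_mlp2 with no biases is an assumption consistent with the six bias-free matrices.

```python
import numpy as np


def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def unpack_block(blocks_weights, i):
    # Slot order from the spec: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].
    w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = blocks_weights[i]
    return w_q, w_k, w_v, w_o, w_mlp1, w_mlp2


def ffn(x, w_mlp1, w_mlp2):
    # d_ff = d_model, so the hidden activation keeps shape (..., d_model).
    return gelu(x @ w_mlp1) @ w_mlp2


num_blocks, d_model = 3, 8  # illustrative sizes only
rng = np.random.default_rng(0)
blocks_weights = rng.normal(size=(num_blocks, 6, d_model, d_model))
w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = unpack_block(blocks_weights, 2)
x = rng.normal(size=(4, 5, d_model))
print(ffn(x, w_mlp1, w_mlp2).shape)  # (4, 5, 8)
```

Because d_ff = d_model, every matrix multiply in the block maps (..., d_model) to (..., d_model), which is what lets all six weights share one packed array.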

Implementation constraints

  • No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm; implement everything from scratch.
  • Manual layer norm over the last (d_model) dimension: (x - mean) / sqrt(var + eps), eps=1e-5, with no learnable scale or shift.
  • Manual softmax over the key dimension of the attention scores.
  • Manual GELU (tanh approximation): 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
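The three manual primitives fit in a few lines of NumPy. This is a sketch, assuming the biased (population) variance in layer norm, which is what nn.LayerNorm uses; the max-shift inside softmax is added for numerical stability and does not change its value.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each d_model vector (last axis); no learnable scale or shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased (population) variance, like nn.LayerNorm
    return (x - mean) / np.sqrt(var + eps)


def softmax(scores):
    # Softmax over the last (key) axis; subtracting the row max avoids overflow
    # and leaves the result mathematically unchanged.
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def gelu(x):
    # tanh approximation of GELU, term for term as written above.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


x = np.random.default_rng(0).normal(size=(2, 3, 8))
h = layer_norm(x)
print(h.mean(axis=-1))                        # each entry is ~0
print(softmax(np.array([1.0, 2.0, 3.0])))      # a probability vector
print(gelu(np.array([-1.0, 0.0, 1.0])))
```

Note that torch.var defaults to the unbiased estimator, so a from-scratch PyTorch version should pass unbiased=False (or correction=0) to match nn.LayerNorm.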

Output convention

The output is a flat (M, vocab_size) tensor, where M is the total number of masked positions across all sequences in the batch. The downstream loss is cross-entropy between these M rows and the original token ids at the masked positions.
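For completeness, a NumPy sketch of that downstream loss. target_ids (the uncorrupted token ids) and masked_cross_entropy are hypothetical names; the spec only says the loss is cross-entropy over the M rows. The point it illustrates is that gathering labels with the same boolean mask keeps them aligned with the logit rows.

```python
import numpy as np


def masked_cross_entropy(logits, target_ids, mask_indicator):
    """Mean cross-entropy over the M masked rows.

    logits:         (M, vocab_size), output of the forward pass
    target_ids:     (N, T) original token ids (hypothetical input; not defined by the spec)
    mask_indicator: (N, T), the same mask used to gather the logits
    """
    # Boolean indexing walks (N, T) in row-major order, the same order used to
    # gather the logits, so row m of logits lines up with labels[m].
    labels = target_ids[mask_indicator > 0.5]  # (M,)
    assert labels.shape[0] == logits.shape[0]
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(labels.shape[0]), labels].mean()


target_ids = np.array([[5, 2, 7, 1], [0, 3, 9, 4]])
mask_indicator = np.array([[0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0]])
uniform_logits = np.zeros((3, 10))  # M = 3 masked positions, vocab_size = 10
loss = masked_cross_entropy(uniform_logits, target_ids, mask_indicator)
print(loss)  # ≈ log(10) ≈ 2.3026 for uniform logits
```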

References

  • Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019.
