medium end_to_end

MLM Forward with Tied Output Head

Implement the masked language model (MLM) forward pass with weight tying. The model is the same bidirectional transformer encoder as in mlm-forward-pass, but the LM head reuses the input embedding matrix as the output projection instead of a separate weight.

Weight Tying

In standard MLM (Task 7), the output head is logits = x @ w_head, where w_head is a dedicated (d_model, vocab_size) matrix. With weight tying, that separate matrix is removed and replaced by:

logits = x @ w_emb.T

The same embedding matrix that maps token ids to vectors on input is transposed and used to project hidden states to vocabulary logits on output. This saves vocab_size × d_model parameters (about 23.4M for BERT-base, with vocab_size = 30,522 and d_model = 768), and tying has been reported to improve perplexity, with the shared matrix acting as a regularizer.
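A minimal NumPy sketch contrasting the untied and tied heads on toy random data (the toy shapes and the head_params helper are hypothetical). It shows that each tied logit is simply the dot product between a hidden state and that token's embedding row, and computes the parameter saving using BERT-base's vocabulary size and d_model.

```python
import numpy as np

# Toy sizes for illustration only (hypothetical).
vocab_size, d_model = 10, 4
rng = np.random.default_rng(0)

w_emb = rng.standard_normal((vocab_size, d_model))  # shared embedding matrix
x = rng.standard_normal((2, 3, d_model))            # hidden states (N, T, d_model)

# Untied head: a dedicated (d_model, vocab_size) matrix.
w_head = rng.standard_normal((d_model, vocab_size))
logits_untied = x @ w_head

# Tied head: reuse the embedding matrix, transposed to (d_model, vocab_size).
logits_tied = x @ w_emb.T

# Logit for token 7 = hidden state dotted with embedding row 7.
same = np.allclose(logits_tied[..., 7], x @ w_emb[7])

def head_params(vocab_size, d_model, tied):
    """Hypothetical helper: weights in embedding + output head (no bias)."""
    emb = vocab_size * d_model
    return emb if tied else 2 * emb

# Saving for BERT-base sizes: 30522 * 768 weights.
saved = head_params(30522, 768, tied=False) - head_params(30522, 768, tied=True)
print(logits_tied.shape, same, saved)  # (2, 3, 10) True 23440896
```

Both heads produce logits of the same shape; the tied version simply has no w_head to store or train.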

Proposed by Press & Wolf (2017) in “Using the Output Embedding to Improve Language Models”. It is used in BERT, GPT-2, and the original T5; many later LLMs also tie these weights, though some keep a separate output matrix.

Pipeline

  1. Token embedding: x = w_emb[input_ids] → shape (N, T, d_model), where input_ids is (N, T) and w_emb is (vocab_size, d_model).
  2. Add position embeddings: x = x + pos_embed (broadcast over batch).
  3. num_blocks pre-LN transformer blocks (bidirectional, with NO causal mask): each block applies x = x + MHA(LN(x)) then x = x + FFN(LN(x)).
  4. Tied LM head: logits_all = x @ w_emb.T → shape (N, T, vocab_size).
  5. Gather at masked positions: return logits_all[mask_indicator > 0.5] → shape (M, vocab_size), where mask_indicator is (N, T) and M is the number of masked positions across the whole batch.
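The five steps can be sketched in NumPy as follows. This is a minimal sketch, not the reference solution, and it rests on several assumptions the spec leaves open: num_heads is a hypothetical extra argument, attention uses the usual 1/sqrt(head_dim) scaling, the FFN is GELU(LN(x) @ w_mlp1) @ w_mlp2 with no biases, layer norm has no learnable scale or shift, and pos_embed has shape (T, d_model). NumPy stands in for PyTorch; each operation has a direct torch tensor counterpart.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Over the last (d_model) axis; population variance; no learnable scale/shift (assumed).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation required by the constraints
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mha(x, w_q, w_k, w_v, w_o, num_heads):
    n, t, d = x.shape
    hd = d // num_heads  # assumes d_model is divisible by num_heads
    def split(h):  # (N, T, d_model) -> (N, heads, T, head_dim)
        return h.reshape(n, t, num_heads, hd).transpose(0, 2, 1, 3)
    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(hd)  # (N, heads, T, T); no causal mask
    out = (softmax(scores) @ v).transpose(0, 2, 1, 3).reshape(n, t, d)
    return out @ w_o

def mlm_forward_tied(input_ids, mask_indicator, w_emb, pos_embed, blocks_weights, num_heads):
    x = w_emb[input_ids]                      # 1. (N, T, d_model)
    x = x + pos_embed                         # 2. pos_embed assumed (T, d_model)
    for w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 in blocks_weights:  # 3. pre-LN blocks
        x = x + mha(layer_norm(x), w_q, w_k, w_v, w_o, num_heads)
        x = x + gelu(layer_norm(x) @ w_mlp1) @ w_mlp2  # assumed FFN form
    logits_all = x @ w_emb.T                  # 4. tied head: (N, T, vocab_size)
    return logits_all[mask_indicator > 0.5]   # 5. (M, vocab_size)

# Toy usage with random weights (all sizes hypothetical).
rng = np.random.default_rng(0)
N, T, d_model, vocab_size, num_blocks, num_heads = 2, 5, 8, 11, 2, 2
input_ids = rng.integers(0, vocab_size, size=(N, T))
mask_indicator = np.zeros((N, T))
mask_indicator[0, 1] = mask_indicator[1, 3] = 1.0
w_emb = 0.1 * rng.standard_normal((vocab_size, d_model))
pos_embed = 0.1 * rng.standard_normal((T, d_model))
blocks_weights = 0.1 * rng.standard_normal((num_blocks, 6, d_model, d_model))
logits = mlm_forward_tied(input_ids, mask_indicator, w_emb, pos_embed, blocks_weights, num_heads)
print(logits.shape)  # (2, 11)
```

Because there is no causal mask, changing a token at a later position in a sequence changes the logits at an earlier masked position of that same sequence, while other sequences in the batch are unaffected.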

Weight Packing

blocks_weights has shape (num_blocks, 6, d_model, d_model). For block i, blocks_weights[i, 0..5] are [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2]. Simplification: d_ff = d_model, so the two MLP matrices are square, like the attention projections.
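A small NumPy sketch of this packing layout, using hypothetical helpers pack_blocks and unpack_block (not part of the spec) to show which slot holds which matrix and how many weights the stack contains.

```python
import numpy as np

# Slot order within blocks_weights[i], as given in the spec.
SLOTS = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")

def pack_blocks(per_block):
    """Hypothetical helper: stack a list of {slot: (d_model, d_model) array}
    dicts into a (num_blocks, 6, d_model, d_model) array."""
    return np.stack([np.stack([blk[name] for name in SLOTS]) for blk in per_block])

def unpack_block(blocks_weights, i):
    """Hypothetical helper: return block i's six matrices as a name -> array dict."""
    return dict(zip(SLOTS, blocks_weights[i]))

rng = np.random.default_rng(0)
num_blocks, d_model = 3, 4
per_block = [{name: rng.standard_normal((d_model, d_model)) for name in SLOTS}
             for _ in range(num_blocks)]
blocks_weights = pack_blocks(per_block)
print(blocks_weights.shape)  # (3, 6, 4, 4)
print(blocks_weights.size)   # num_blocks * 6 * d_model**2 = 288
w = unpack_block(blocks_weights, 1)
print(np.array_equal(w["w_v"], per_block[1]["w_v"]))  # True
```

Since every slot is (d_model, d_model), the encoder stack holds num_blocks × 6 × d_model² weights, independent of vocab_size.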

Implementation Constraints

  • No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm: implement everything from scratch.
  • Manual layer norm over the last (d_model) dimension: (x - mean) / sqrt(var + eps), eps=1e-5.
  • Manual GELU (tanh approximation): 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
  • No separate w_head argument: the tied head uses w_emb.T.
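A quick stdlib-only check of the two manual formulas. It is a sketch under stated assumptions: layer norm uses the population variance (dividing by d_model, as nn.LayerNorm does) with no learnable scale or shift, and gelu_exact is the erf-based GELU, included only for comparison; it is not what the constraints ask for.

```python
import math

def gelu_tanh(x):
    # The tanh approximation required by the constraints.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def gelu_exact(x):
    # Exact GELU, x * Phi(x), via the error function (comparison only).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def layer_norm(row, eps=1e-5):
    # Normalize one d_model-length vector; population variance (assumed), no scale/shift.
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]

grid = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_tanh(v) - gelu_exact(v)) for v in grid)

y = layer_norm([1.0, 2.0, 3.0, 4.0])
y_mean = sum(y) / len(y)
y_var = sum(v * v for v in y) / len(y)  # mean is ~0, so this is the variance
print(f"max |tanh - exact| on [-6, 6]: {max_err:.1e}")
print(f"LN output mean {y_mean:.2e}, variance {y_var:.5f}")
```

The normalized vector has mean close to 0 and variance slightly below 1 (eps in the denominator pulls it down a little), and the tanh form tracks the exact GELU closely across the grid.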

References

  • Press & Wolf, “Using the Output Embedding to Improve Language Models”, EACL 2017.
  • Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019.

Hints

mlm bert weight-tying
