MLM Forward with Tied Output Head
Implement the masked language model (MLM) forward pass with weight tying — same
bidirectional transformer encoder as mlm-forward-pass, but the LM head reuses the
input embedding matrix as the output projection instead of a separate weight.
Weight Tying
In standard MLM (Task 7), the output head is logits = x @ w_head where w_head
is a dedicated (d_model, vocab_size) matrix. With weight tying, that separate
matrix is removed and replaced by:
logits = x @ w_emb.T
The same embedding matrix that maps token ids to vectors on input is transposed
and used to project hidden states to vocabulary logits on output. This saves
vocab_size × d_model parameters — often millions in real models — and typically
improves perplexity by acting as a regularizer.
The technique was introduced by Press & Wolf (2017) in “Using the Output Embedding to Improve Language Models” and has been adopted by BERT, GPT-2, T5, and many later language models.
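As a minimal sketch of the idea, the snippet below uses one matrix for both the input lookup and the output projection. The sizes are illustrative, the embedded tokens stand in for encoder hidden states, and `saved_params` is a hypothetical helper that is not part of the task.

```python
import torch

torch.manual_seed(0)
vocab_size, d_model = 100, 16              # small illustrative sizes
w_emb = torch.randn(vocab_size, d_model)   # shared by input lookup and output head

input_ids = torch.tensor([[3, 7, 7, 42]])  # shape (N=1, T=4)
x = w_emb[input_ids]                       # input side: row lookup -> (1, 4, d_model)
logits = x @ w_emb.T                       # output side: tied projection -> (1, 4, vocab_size)


def saved_params(vocab_size, d_model):
    """Parameters removed by dropping a separate (d_model, vocab_size) head."""
    return vocab_size * d_model


# Sizes similar to BERT-base (vocab 30522, d_model 768): roughly 23.4 million.
bert_like = saved_params(30522, 768)
```

Because positions 1 and 2 hold the same token id and no encoder runs in between, their logit rows are identical here. In the real model the encoder makes them differ.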
Pipeline
- Token embedding: x = w_emb[input_ids] → shape (N, T, d_model).
- Add position embeddings: x = x + pos_embed (broadcast over batch).
- num_blocks pre-LN transformer blocks (bidirectional, NO causal mask): each block applies x = x + MHA(LN(x)), then x = x + FFN(LN(x)).
- Tied LM head: logits_all = x @ w_emb.T → shape (N, T, vocab_size).
- Gather at masked positions: return logits_all[mask_indicator > 0.5] → shape (M, vocab_size), where M is the number of masked positions.
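The shape flow of the pipeline can be traced with a small sketch. It leaves out the transformer blocks and assumes `pos_embed` has shape (T, d_model) and `mask_indicator` is a float tensor of shape (N, T); neither shape is stated explicitly above, and all sizes are illustrative.

```python
import torch

torch.manual_seed(0)
N, T, d_model, vocab_size = 2, 4, 8, 50      # illustrative sizes
w_emb = torch.randn(vocab_size, d_model)
pos_embed = torch.randn(T, d_model)          # assumed shape (T, d_model)
input_ids = torch.randint(0, vocab_size, (N, T))
mask_indicator = torch.tensor([[0.0, 1.0, 0.0, 0.0],
                               [1.0, 0.0, 0.0, 1.0]])  # 1.0 marks a masked position

x = w_emb[input_ids]                         # (N, T, d_model)
x = x + pos_embed                            # (T, d_model) broadcasts over the batch axis
# ... the pre-LN transformer blocks would run here ...
logits_all = x @ w_emb.T                     # (N, T, vocab_size)

# Boolean indexing over the first two axes keeps one row per masked position,
# in row-major order: (0, 1), then (1, 0), then (1, 3).
logits = logits_all[mask_indicator > 0.5]    # (M, vocab_size) with M = 3
```

The gather flattens the batch and time axes, so the caller needs the same row-major ordering when pairing the returned rows with target token ids.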
Weight Packing
blocks_weights has shape (num_blocks, 6, d_model, d_model).
For block i, blocks_weights[i, 0..5] are [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].
Simplification: d_ff = d_model.
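One possible way to unpack the packed tensor is sketched below. The helper name `unpack_block` and the dictionary layout are hypothetical, chosen only for illustration.

```python
import torch

WEIGHT_NAMES = ["w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2"]


def unpack_block(blocks_weights, i):
    """Return the six (d_model, d_model) matrices of block i, keyed by name."""
    # Iterating over blocks_weights[i] walks its first axis (size 6).
    return dict(zip(WEIGHT_NAMES, blocks_weights[i]))


num_blocks, d_model = 3, 8                   # illustrative sizes
blocks_weights = torch.randn(num_blocks, 6, d_model, d_model)

w = unpack_block(blocks_weights, 1)
# With d_ff = d_model, both MLP matrices are square, like the attention ones.
shapes = {name: tuple(mat.shape) for name, mat in w.items()}
```

Since every matrix is (d_model, d_model), each block contributes 6 × d_model² weights under this packing.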
Implementation Constraints
- No nn.MultiheadAttention, no F.scaled_dot_product_attention, no nn.LayerNorm: implement everything from scratch.
- Manual layer norm: (x - mean) / sqrt(var + eps), eps=1e-5.
- Manual GELU: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
- No separate w_head argument: the tied head uses w_emb.T.
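The from-scratch pieces might look like the sketch below. It is not the reference solution and rests on several assumptions not stated above: a `num_heads` argument, biased variance and no learnable scale or shift in layer norm, no bias terms, and FFN(x) = GELU(x @ w_mlp1) @ w_mlp2. The `F.layer_norm` and `F.gelu` calls appear only as test references, never inside the implementation.

```python
import math
import torch
import torch.nn.functional as F


def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis; biased variance is an assumption.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def gelu(x):
    # tanh approximation of GELU, as given in the constraints.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def mha(x, w_q, w_k, w_v, w_o, num_heads):
    # Bidirectional attention: every position attends to every other (no mask).
    n, t, d = x.shape
    d_head = d // num_heads

    def split(h):
        return h.reshape(n, t, num_heads, d_head).transpose(1, 2)  # (N, H, T, d_head)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)           # (N, H, T, T)
    attn = torch.softmax(scores, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(n, t, d)              # merge heads
    return out @ w_o


def pre_ln_block(x, w, num_heads):
    # w has shape (6, d_model, d_model): [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].
    w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = w
    x = x + mha(layer_norm(x), w_q, w_k, w_v, w_o, num_heads)
    x = x + gelu(layer_norm(x) @ w_mlp1) @ w_mlp2
    return x


# Sanity checks against library references (used for testing only).
torch.manual_seed(0)
x = torch.randn(2, 5, 8)
ln_ok = torch.allclose(layer_norm(x), F.layer_norm(x, (8,), eps=1e-5), atol=1e-5)
gelu_ok = torch.allclose(gelu(x), F.gelu(x, approximate="tanh"), atol=1e-5)
y = pre_ln_block(x, torch.randn(6, 8, 8) * 0.5, num_heads=2)
```

Comparing the manual functions with library equivalents in a separate test is a cheap way to catch a wrong variance convention or GELU constant while keeping the implementation itself free of the disallowed modules.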
References
- Press & Wolf, “Using the Output Embedding to Improve Language Models”, EACL 2017.
- Devlin et al., “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding”, NAACL 2019.