Tied Input/Output Embeddings
Implement the tied LM head: the output projection used in many language models, where the input embedding matrix is reused as the output weight matrix.
Weight tying: the core idea
A language model needs two matrices of shape (vocab_size, d_model):
- Input embedding E: maps token ids to dense vectors before the first transformer block.
- Output projection W_o: maps the final hidden state to logits over the vocabulary after the last transformer block.
Weight tying (Press & Wolf, 2017, "Using the Output Embedding to Improve
Language Models") sets W_o = E, so both directions share the exact
same parameters. The forward pass then becomes:
logits = hidden @ E.T
where E.T has shape (d_model, vocab_size), and logits has shape
(..., vocab_size).
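The tied forward pass can be sketched as follows. This is a minimal illustration, assuming NumPy; the toy sizes (vocab_size = 5, d_model = 3), the random weights, and the token ids are made up for the example, and the transformer blocks are replaced by an identity stand-in so the snippet stays self-contained.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model = 5, 3          # toy sizes, assumed for illustration

# One shared parameter matrix: each row is a token vector.
E = rng.standard_normal((vocab_size, d_model))

token_ids = np.array([2, 0, 4])     # a short example sequence
embedded = E[token_ids]             # input side: row lookup, shape (3, d_model)

# Stand-in for the transformer blocks (assumption): pass the embeddings
# through unchanged so that only the tying itself is on display.
hidden = embedded

logits = hidden @ E.T               # output side: shape (3, vocab_size)
print(logits.shape)                 # (3, 5)
```

Because the same E is used in both directions, each logit is the dot product between a hidden state and a token's embedding row.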
Why tie weights?
- Parameter savings: removes vocab_size × d_model parameters. For GPT-2 small (vocab 50 257, d_model 768) that is ~38.6 M parameters, roughly 30 % of the 124 M-parameter model.
- Regularization / better perplexity: the token vectors and the projection rows are forced to live in the same space, which often improves language-model perplexity compared to separate matrices.
- Wide adoption: GPT-2 and the original T5 tie their embeddings, as do many other language models, though some larger models (for example LLaMA and Mistral 7B) keep separate input and output matrices.
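The parameter-savings arithmetic can be checked in a couple of lines. The helper name tied_param_savings is hypothetical and introduced only for this sketch; the sizes are the GPT-2 figures quoted above.

```python
def tied_param_savings(vocab_size: int, d_model: int) -> int:
    """Parameters in one (vocab_size, d_model) matrix, i.e. what tying removes."""
    return vocab_size * d_model

saved = tied_param_savings(50_257, 768)   # GPT-2 sizes quoted in the text
print(f"{saved:,} parameters saved")      # 38,597,376 parameters saved
```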
Operation:
Given:
- hidden: shape (..., d_model), the final hidden states from the transformer (any batch/sequence prefix; the last dim is d_model).
- embedding_matrix: shape (vocab_size, d_model), the shared input embedding matrix.
Return: hidden @ embedding_matrix.T of shape (..., vocab_size).
Inputs:
- hidden: shape (..., d_model), the final hidden states.
- embedding_matrix: shape (vocab_size, d_model), the shared input embedding table.
Output: shape (..., vocab_size), the raw (pre-softmax) logits.
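One possible way to write the operation, as a sketch rather than the reference solution: it assumes NumPy arrays, and the function name tied_lm_head and the shape checks are illustrative choices that the problem statement does not specify.

```python
import numpy as np

def tied_lm_head(hidden, embedding_matrix):
    """Project hidden states to logits using the shared embedding matrix."""
    hidden = np.asarray(hidden)
    embedding_matrix = np.asarray(embedding_matrix)
    if embedding_matrix.ndim != 2:
        raise ValueError("embedding_matrix must have shape (vocab_size, d_model)")
    if hidden.shape[-1] != embedding_matrix.shape[1]:
        raise ValueError("last dim of hidden must equal d_model")
    # (..., d_model) @ (d_model, vocab_size) -> (..., vocab_size);
    # matmul broadcasts over any leading batch/sequence dimensions.
    return hidden @ embedding_matrix.T

# Usage with made-up inputs: batch of 2, sequence length 4, d_model 3, vocab 5.
hidden = np.ones((2, 4, 3))
E = np.arange(15, dtype=float).reshape(5, 3)
logits = tied_lm_head(hidden, E)
print(logits.shape)   # (2, 4, 5)
```

Transposing is a view in NumPy, so no copy of the embedding matrix is made; the same array object can be passed to both the input lookup and this head.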