medium primitives

Tied Input/Output Embeddings

Implement the tied LM head: the output projection used in many language models, where the input embedding matrix is reused as the output weight matrix.

Weight tying: the core idea

A language model needs two matrices of shape (vocab_size, d_model):

  1. Input embedding E: maps token ids → dense vectors before the first transformer block.
  2. Output projection W_o: maps the final hidden state → logits over the vocabulary after the last transformer block.

Weight tying (Press & Wolf, 2017, “Using the Output Embedding to Improve Language Models”) sets W_o = E, so both directions share the exact same parameters. The forward pass then becomes:

logits = hidden @ E.T

where E.T has shape (d_model, vocab_size), and logits has shape (..., vocab_size).
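
A minimal NumPy sketch of the idea, assuming toy sizes and an identity placeholder in place of the transformer blocks (both are illustrative assumptions, not part of the problem statement). It shows the same array E serving as the lookup table on the input side and as the projection on the output side:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model = 10, 4             # toy sizes, chosen only for illustration

# One shared matrix E plays both roles.
E = rng.normal(size=(vocab_size, d_model))

token_ids = np.array([[1, 5, 2]])       # shape (batch=1, seq=3)

# Input side: row lookup, token ids -> dense vectors.
x = E[token_ids]                        # shape (1, 3, d_model)

# Placeholder for the transformer blocks (identity, an assumption for this sketch).
hidden = x

# Output side: the same E, transposed, maps hidden states to logits.
logits = hidden @ E.T                   # shape (1, 3, vocab_size)
print(logits.shape)                     # (1, 3, 10)
```

Because no second matrix is ever created, any update to E changes both the input vectors and the output scores.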

Why tie weights?

  • Parameter savings: removes vocab_size × d_model parameters. For GPT-2 small (vocab 50 257, d_model 768) that is about 38.6 M parameters, equivalent to roughly 31 % of the 124 M parameters in the tied model.
  • Regularization / better perplexity: the token vectors and the projection rows are forced to live in the same space, which often improves language-model perplexity compared to separate matrices.
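  • Wide adoption: GPT-2 and the original T5 tie their embeddings, as do many smaller models. Tying is not universal, though: the original LLaMA and Mistral 7B, for example, keep separate input and output matrices.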
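
The parameter-savings arithmetic can be checked directly. This is a small sketch; the 124 M total is an assumption based on the commonly cited size of GPT-2 small with tied weights, not a figure taken from the problem statement:

```python
vocab_size, d_model = 50_257, 768       # GPT-2 small sizes quoted in the list above

# Parameters in one (vocab_size, d_model) matrix, i.e. what tying avoids duplicating.
saved = vocab_size * d_model
print(f"{saved:,}")                     # 38,597,376

# Assumption: approximate parameter count of GPT-2 small with tied weights.
total_tied = 124_000_000
fraction = saved / total_tied
print(f"{fraction:.0%}")                # 31%
```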

Operation:

Given:

  • hidden: shape (..., d_model), the final hidden states from the transformer (any batch/sequence prefix, last dim is d_model).
  • embedding_matrix: shape (vocab_size, d_model), the shared input embedding matrix.

Return: hidden @ embedding_matrix.T of shape (..., vocab_size).

Inputs:

  • hidden: shape (..., d_model), the final hidden states.
  • embedding_matrix: shape (vocab_size, d_model), the shared input embedding table.

Output: shape (..., vocab_size), the raw (pre-softmax) logits.
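
One possible way to write the operation in NumPy is sketched below. The function name tied_lm_head and the shape checks are assumptions made for illustration; this is not the site's reference solution, and the expected framework may differ:

```python
import numpy as np

def tied_lm_head(hidden: np.ndarray, embedding_matrix: np.ndarray) -> np.ndarray:
    """Project hidden states to vocabulary logits with the shared embedding matrix.

    hidden:           shape (..., d_model)
    embedding_matrix: shape (vocab_size, d_model)
    returns:          shape (..., vocab_size), raw pre-softmax logits
    """
    if embedding_matrix.ndim != 2:
        raise ValueError("embedding_matrix must have shape (vocab_size, d_model)")
    if hidden.shape[-1] != embedding_matrix.shape[1]:
        raise ValueError("last dim of hidden must equal d_model")
    # matmul broadcasts over any leading batch/sequence dimensions of hidden.
    return hidden @ embedding_matrix.T

# Usage with small, hand-checkable inputs (hypothetical values).
hidden = np.ones((2, 3, 4))                                   # (batch, seq, d_model)
embedding_matrix = np.arange(20, dtype=float).reshape(5, 4)   # (vocab_size, d_model)
logits = tied_lm_head(hidden, embedding_matrix)
print(logits.shape)                                           # (2, 3, 5)
```

With an all-ones hidden state, each logit is just the sum of the corresponding embedding row, which makes the result easy to verify by hand. A 1-D hidden vector of shape (d_model,) also works and yields shape (vocab_size,).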

