
Tied Input/Output Embedding

Why this matters

A language model has two big (V, D) matrices: the input embedding, which maps token IDs to vectors, and the output projection, which maps a hidden vector back to vocabulary logits. With V = 50000 and D = 4096, each holds 50000 × 4096 ≈ 205M parameters, about 410M for the pair. That is more than three times the entire 125M-parameter GPT-3 Small.
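
The arithmetic above can be spelled out in plain Python. The sizes are the example values from the paragraph, not those of any particular model.

```python
# Example sizes from the paragraph above (illustrative, not a specific model).
V, D = 50_000, 4_096

per_matrix = V * D            # one (V, D) matrix
untied = 2 * per_matrix       # separate input embedding + output projection
tied = per_matrix             # a single shared matrix

print(f"{per_matrix:,}")                     # 204,800,000
print(f"{untied:,}")                         # 409,600,000
print(f"saved by tying: {untied - tied:,}")  # saved by tying: 204,800,000
```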

Weight tying (Press & Wolf 2017) just reuses the same matrix:

embed_in[t]  = W[t]            # token → vector
logits_out   = h @ W.T         # vector → vocab logits

The output projection is the transpose of the input embedding. Two benefits:

  1. Saves V · D parameters by dropping one of the model’s two largest matrices.
  2. Regularises: input and output share representations, often improving perplexity by a small but real margin.

Used in GPT-2, T5, BART, and most decoder-only LMs that aren’t enormous. Very large models sometimes untie the two matrices: at that scale the embedding is a small fraction of the total parameter count, and untying gives slightly more capacity.

The Flax pattern

Don’t use nn.Embed here: it creates and owns its parameter, so tying would have to go through its attend method. Instead, declare a single param explicitly and write both directions yourself:

import flax.linen as nn

class TiedIO(nn.Module):
    vocab_size: int
    d_model: int

    @nn.compact
    def __call__(self, token_ids):
        embedding = self.param(
            "embedding",
            nn.initializers.normal(stddev=0.02),
            (self.vocab_size, self.d_model),
        )
        x = embedding[token_ids]                 # input lookup: (T, D)
        logits = x @ embedding.T                 # output proj:  (T, V)
        return logits

The crucial line is x @ embedding.T. Same (V, D) matrix, transposed on the way out. There’s only ONE param in the tree.

Why does the same matrix work in both directions?

Conceptually, the input embedding places a token ID in a “meaning space”; the output projection asks “how similar is the model’s hidden vector to each token’s embedding?” The dot product, and hence the logit, is high for tokens whose embeddings point in roughly the same direction as the hidden vector. Same geometry, opposite direction.

Formally, tying is a constraint (output matrix = input matrix), which gives a stronger inductive bias: each token’s vector serves a dual purpose. Empirically this regularises small and medium models well.

Worked example

import jax
import jax.numpy as jnp

V, D = 4, 3
W = 0.02 * jax.random.normal(jax.random.PRNGKey(0), (V, D))  # the ONE matrix, like normal(stddev=0.02)
ids = jnp.array([0, 2])                         # T = 2

x = W[ids]                                      # (2, 3): rows 0 and 2 of W
logits = x @ W.T                                # (2, 4)
# logits[t, v] = <W[ids[t]], W[v]>: similarity to every token

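Note that logits[t, ids[t]] = ||W[ids[t]]||², the squared norm of the input token’s own embedding. When the embeddings are roughly orthogonal, this is the largest entry in row t. So with nothing in between, the tied model favours predicting the input token itself: a self-similarity bias.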
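
A small numerical check of the self-similarity bias follows. It reuses the toy sizes from the worked example. The exactly orthogonal matrix W_orth is an illustrative assumption; a trained model does not produce one.

```python
import jax
import jax.numpy as jnp

# Same toy setup as the worked example.
V, D = 4, 3
W = 0.02 * jax.random.normal(jax.random.PRNGKey(0), (V, D))
ids = jnp.array([0, 2])
logits = W[ids] @ W.T                              # (T, V)

# Each position's "self" logit equals the squared norm of its own embedding.
T = ids.shape[0]
self_logits = logits[jnp.arange(T), ids]
sq_norms = jnp.sum(W[ids] ** 2, axis=-1)
print(bool(jnp.allclose(self_logits, sq_norms, rtol=1e-4)))   # True

# With exactly orthogonal rows, the input token wins the argmax in every row.
W_orth = 0.5 * jnp.eye(V)                          # (V, V): D = V here
logits_orth = W_orth[ids] @ W_orth.T
print(jnp.argmax(logits_orth, axis=-1))            # [0 2]
```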

Common pitfalls

  • Using nn.Embed: it creates its own parameter, so tying has to go through its attend method (a matmul against the transposed embedding). That works, but declaring the embedding once with self.param keeps both directions explicit.
  • Not transposing: x @ embedding is a (T, D) @ (V, D) matmul. It fails with a shape error unless D == V, and when D == V it silently computes the wrong thing. The output projection needs embedding.T, of shape (D, V).
  • Float token_ids: JAX only accepts integer or boolean index arrays, so cast to jnp.int32 before the lookup.
  • Two separate self.param calls: that creates two independent parameters, so nothing is tied. Use exactly ONE call.
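
Two of these pitfalls can be reproduced in a few lines. The sketch assumes a square matrix (V == D), the one case where a missing transpose runs without error. It also shows that JAX rejects a float index array until it is cast.

```python
import jax
import jax.numpy as jnp

# Square case (V == D): forgetting the transpose no longer fails loudly.
V = D = 4
W = jax.random.normal(jax.random.PRNGKey(0), (V, D))
x = W[jnp.array([1, 3])]          # (T, D) with T = 2

right = x @ W.T   # (T, V): dot product of each x with each token's row of W
wrong = x @ W     # also (T, V) here, but combines columns of W, not token rows
print(right.shape == wrong.shape)          # True
print(bool(jnp.allclose(right, wrong)))    # False: W is not symmetric

# Float token ids: JAX rejects non-integer index arrays with a TypeError.
float_ids = jnp.array([1.0, 3.0])
try:
    W[float_ids]
    float_index_failed = False
except TypeError:
    float_index_failed = True
print(float_index_failed)

# After casting, the lookup works as intended.
print(W[float_ids.astype(jnp.int32)].shape)   # (2, 4)
```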

Problem

Implement tied_io_embed(seed, token_ids, vocab_size, d_model):

  1. Cast vocab_size, d_model to int. Cast token_ids to jnp.int32.
  2. Build an nn.Module (compact) that declares ONE param, embedding, of shape (V, D), initialised with normal(stddev=0.02).
  3. Forward: look up x = embedding[token_ids] → (T, D), then project with x @ embedding.T → (T, V). Return the logits.
  4. Init with jax.random.PRNGKey(seed), apply, and return the flattened logits.

Inputs:

  • seed: int.
  • token_ids: 1-D float array (cast inside).
  • vocab_size: int V.
  • d_model: int D.

Output: 1-D array of length T · V (flattened logits).
