Tied Input/Output Embedding
Why this matters
A language model has two big (V, D) matrices: the input embedding
that maps token IDs to vectors, and the output projection that maps
a hidden vector back to vocabulary logits. With V = 50000, D = 4096
each is about 205M parameters, roughly 410M together, or about 6% of
a 6.7B-parameter model.
Weight tying (Press & Wolf 2017) just reuses the same matrix:
embed_in[t] = W[t] # token → vector
logits_out = h @ W.T # vector → vocab logits
The output projection is the transpose of the input embedding. Two benefits:
- Halves the embedding parameters: one (V, D) matrix is stored instead of two.
- Regularises: input and output share representations, often improving perplexity by a small but real margin.
Used in GPT-2, T5, BART, and most decoder-only LMs that aren’t enormous. Very large models sometimes untie because the embedding matrix is then a small fraction of the total parameter count, and untied weights give slightly more capacity.
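To make the saving concrete, the parameter counts for the example sizes quoted earlier can be worked out directly. This is plain arithmetic; the variable names are only illustrative.

```python
# Parameter counts for the example sizes in the text (V = 50000, D = 4096).
V, D = 50_000, 4_096

per_matrix = V * D          # one (V, D) matrix
untied = 2 * per_matrix     # separate input embedding and output projection
tied = per_matrix           # one shared matrix used in both directions

print(f"per matrix: {per_matrix:,}")   # 204,800,000
print(f"untied:     {untied:,}")       # 409,600,000
print(f"tied:       {tied:,}")         # 204,800,000
```

Tying removes exactly one (V, D) matrix, so the saving grows with both vocabulary size and model width.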
The Flax pattern
Don’t use nn.Embed (it owns its parameter). Declare a single param
explicitly and do both directions yourself:
import flax.linen as nn

class TiedIO(nn.Module):
    vocab_size: int
    d_model: int

    @nn.compact
    def __call__(self, token_ids):
        # One (V, D) parameter, used for both lookup and projection.
        embedding = self.param(
            "embedding",
            nn.initializers.normal(stddev=0.02),
            (self.vocab_size, self.d_model),
        )
        x = embedding[token_ids]   # input lookup: (T, D)
        logits = x @ embedding.T   # output proj: (T, V)
        return logits
The crucial line is x @ embedding.T. Same (V, D) matrix, transposed
on the way out. There’s only ONE param in the tree.
Why does the same matrix work in both directions?
Conceptually, the input embedding embeds a token ID into a “meaning space”; the output projection asks “how similar is the model’s hidden vector to each token’s embedding?” — the dot product is high for semantically similar tokens. Same geometry, different direction.
Mathematically it’s a stronger inductive bias: each token’s vector serves dual purpose. Empirically this regularises well for small/medium models.
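The similarity reading can be seen on a tiny hand-built example. The embedding rows below are made-up values, chosen so that tokens 0 and 1 point in nearly the same direction and token 2 is orthogonal to token 0; they are not from any trained model.

```python
import jax.numpy as jnp

# Hypothetical 3-token vocabulary in a 2-D "meaning space".
W = jnp.array([
    [1.0, 0.0],   # token 0
    [0.9, 0.1],   # token 1: close to token 0
    [0.0, 1.0],   # token 2: orthogonal to token 0
])

h = jnp.array([1.0, 0.0])   # a hidden vector that matches token 0 exactly
logits = h @ W.T            # dot product of h with every token embedding

print(logits)                     # highest for token 0, then token 1, then token 2
print(int(jnp.argmax(logits)))    # 0
```

The ranking of the logits follows the geometry of the rows of W, which is the sense in which the output projection asks how similar the hidden vector is to each token's embedding.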
Worked example
import jax
import jax.numpy as jnp

V, D = 4, 3
# The ONE matrix; the seed 0 is an arbitrary choice.
W = 0.02 * jax.random.normal(jax.random.PRNGKey(0), (V, D))
ids = jnp.array([0, 2])  # T=2
x = W[ids]               # (2, 3): rows 0 and 2 of W
logits = x @ W.T         # (2, 4)
# logits[t, v] = <W[ids[t]], W[v]>: similarity to every token
Note logits[t, ids[t]] = ||W[ids[t]]||², which is the largest entry
when the embeddings are roughly orthogonal — a self-similarity bias.
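The identity logits[t, ids[t]] = ||W[ids[t]]||² can be checked numerically. This is a small sketch with the same shapes as the worked example; the seed is arbitrary, and nothing here depends on the embeddings being orthogonal.

```python
import jax
import jax.numpy as jnp

V, D = 4, 3
W = 0.02 * jax.random.normal(jax.random.PRNGKey(0), (V, D))
ids = jnp.array([0, 2])

x = W[ids]                 # (T, D)
logits = x @ W.T           # (T, V)

# Pick out logits[t, ids[t]] for every position t.
self_logits = logits[jnp.arange(ids.shape[0]), ids]
# Squared norm of each looked-up embedding row.
sq_norms = jnp.sum(x * x, axis=-1)

print(self_logits)
print(sq_norms)
```

The two printed vectors agree up to floating-point error, since each self-logit is the dot product of an embedding row with itself.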
Common pitfalls
- Using nn.Embed: it has its own parameter. To tie, you’d need to pass that param explicitly to the output projection, which is clumsier than declaring embedding once with self.param.
- Not transposing: x @ embedding has the wrong shape (matmul (T, D) @ (V, D) doesn’t compute). It must be embedding.T for the (D, V) projection.
- Float token_ids: cast to jnp.int32 first.
- Two separate self.param calls: that creates two parameters, not tied. Only ONE call.
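Two of these pitfalls are easy to reproduce outside a module. The sketch below uses a plain array in place of the Flax parameter, with made-up sizes (V = 5, D = 3); it casts float IDs before indexing and shows that the untransposed matmul is rejected while the transposed one has shape (T, V).

```python
import jax.numpy as jnp

V, D = 5, 3
embedding = jnp.ones((V, D))             # stand-in for the shared parameter
float_ids = jnp.array([0.0, 2.0, 4.0])   # IDs arriving as floats

ids = float_ids.astype(jnp.int32)        # cast before using them as indices
x = embedding[ids]                       # (T, D)

try:
    _ = x @ embedding                    # (T, D) @ (V, D): inner dims 3 and 5 differ
    matmul_failed = False
except (TypeError, ValueError):
    matmul_failed = True

logits = x @ embedding.T                 # (T, D) @ (D, V) -> (T, V)

print(matmul_failed)   # True
print(logits.shape)    # (3, 5)
```

Note that the shape error only appears when V and D differ; with V equal to D the untransposed product would run and silently give the wrong result.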
Problem
Implement tied_io_embed(seed, token_ids, vocab_size, d_model):
- Cast vocab_size, d_model to int. Cast token_ids to jnp.int32.
- Build a nn.Module (compact) that declares ONE param embedding of shape (V, D), init normal(stddev=0.02).
- Forward: lookup embedding[token_ids] → (T, D), then project with ... @ embedding.T → (T, V). Return logits.
- Init with jax.random.PRNGKey(seed), apply, return flattened.
Inputs:
- seed: int.
- token_ids: 1-D float array (cast inside).
- vocab_size: int V.
- d_model: int D.
Output: 1-D array of length T · V (flattened logits).