NNX Tied I/O Embed
Why this matters
Weight tying is the trick of using the SAME parameter matrix
for two different roles in a model. The classic case in language
models: tie the input embedding (V, D) to the output projection
head (D, V) by using its transpose. One matrix; two uses.
Why? Two reasons:
- Parameter savings. In GPT-2 small (V=50257, D=768), the output head ALONE is 38M params. Tying it to the embedding cuts 38M from the total. In bigger models (V=128k, D=4096) it's half a billion parameters saved.
- Implicit regularization. The model is forced to use one representation for both "what does this token mean as input?" and "how do I score this token at the output?". Empirically, this improves perplexity at small to medium scales.
Used by: GPT-2, GPT-3, T5, BART, BERT (kind of), most modern decoder-only LLMs. Untied is sometimes preferred at the largest scales where the regularization is less needed.
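The parameter savings quoted above are just V * D. A quick check of the arithmetic, assuming a bias-free output head and reading "128k" as 128,000 (head_params is a hypothetical helper name, not part of any library):

```python
def head_params(vocab_size: int, d_model: int) -> int:
    """Weights in a bias-free (D, V) output head; tying removes exactly these."""
    return vocab_size * d_model


gpt2_small_saved = head_params(50257, 768)
large_saved = head_params(128_000, 4096)  # assumption: "128k" means 128,000

print(f"{gpt2_small_saved:,}")  # 38,597,376
print(f"{large_saved:,}")       # 524,288,000
```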
The mechanic
# embedding: (V, D) param
# token_ids: (T,) ints
# Input pathway: lookup (treats embedding's rows as token vectors).
embed = embedding[token_ids] # (T, D)
# Output pathway: project T x D back to T x V.
logits = embed @ embedding.T # (T, D) @ (D, V) -> (T, V)
Multiplying by the transposed matrix answers "how similar is
embed[t] to each row of the embedding table?". Each logit is
an inner product between the position's hidden state and the
candidate token's embedding vector: a simple, principled scoring
function that training keeps aligned, because both roles share one matrix.
For this problem we skip everything in between (no transformer blocks); the model is just embedding + tied output. So each logit is the dot product between an input token's embedding vector and one row of the same table.
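A minimal runnable version of the mechanic in plain JAX, with toy sizes and a fixed PRNG key chosen only for illustration. It also checks the dot-product reading: the logit columns selected at the input ids form the Gram matrix of the input tokens' embeddings.

```python
import jax
import jax.numpy as jnp

V, D = 10, 4  # toy sizes, assumed for illustration
key = jax.random.PRNGKey(0)
embedding = jax.random.normal(key, (V, D)) * (1.0 / jnp.sqrt(D))  # (V, D)
token_ids = jnp.array([2, 7, 2], dtype=jnp.int32)                 # (T,)

# Input pathway: row lookup.
embed = embedding[token_ids]   # (T, D)

# Output pathway: score every position against every vocabulary row.
logits = embed @ embedding.T   # (T, D) @ (D, V) -> (T, V)

# Columns at the input ids are the pairwise dot products of the inputs.
gram = logits[:, token_ids]    # (T, T)

print(logits.shape)  # (3, 10)
print(gram.shape)    # (3, 3)
```

Because positions 0 and 2 hold the same token id, their logit rows are identical; with no layers in between, the output depends only on the token, not on its position.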
API: nnx.Param with custom init
nnx.Embed is the off-the-shelf token-embedding module, but the
point here is to be EXPLICIT about the matrix and reuse it. So
we define our own nnx.Param:
import jax
import jax.numpy as jnp
from flax import nnx

class TiedEmbed(nnx.Module):
    def __init__(self, vocab_size, d_model, rngs):
        key = rngs.params()
        self.embedding = nnx.Param(
            jax.random.normal(key, (vocab_size, d_model))
            * (1.0 / jnp.sqrt(d_model))
        )

    def __call__(self, token_ids):
        embed = self.embedding.value[token_ids]   # (T, D)
        logits = embed @ self.embedding.value.T   # (T, V)
        return logits
The 1/sqrt(D) scaling is a common variance-scaling init in the
Glorot/Xavier spirit (not the exact Glorot formula). rngs.params()
pulls a params-stream key from the Rngs container, the canonical
way to get a key inside an nnx Module's __init__.
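One way to see what the 1/sqrt(D) factor does: each entry has variance 1/D, so a row's expected squared norm is D * (1/D) = 1, whatever D is. The sketch below checks this numerically in plain JAX; the table size and PRNG key are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

V, D = 1000, 768  # arbitrary sizes for the check
key = jax.random.PRNGKey(0)
table = jax.random.normal(key, (V, D)) * (1.0 / jnp.sqrt(D))

# Per-entry standard deviation, rescaled so the target value is 1.
scaled_std = float(jnp.std(table) * jnp.sqrt(D))

# Squared norm of each row, i.e. a token's logit against itself under tying.
mean_sq_norm = float(jnp.mean(jnp.sum(table ** 2, axis=-1)))

print(round(scaled_std, 2), round(mean_sq_norm, 2))
```

Both values come out close to 1, which keeps the tied logits at a moderate scale at initialization.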
Why use .value.T, not .T
nnx.Param wraps the array. Arithmetic ops auto-unwrap (so the
unwrap itself works in embed @ self.embedding, even though that
product still needs a transpose to have the right shape), but
attribute lookups like .T go to the wrapper, not the array.
Depending on the Flax version, the wrapper may not forward them, in which case you'd get an error.
Pattern: always call .value before any attribute access on the
array (.shape, .T, .dtype, indexing). Arithmetic on the
wrapper (+, -, @) is safe.
Compare with mini-GPT
Pos 43 (nnx-mini-gpt) used the same trick at the END of a full
transformer:
x = self.ln_f(x)
return x @ self.embed.embedding.value.T
Same operation, different surrounding context. Once you've seen
embed.embedding.value.T once, you'll recognize it in any LM
architecture that ties its weights.
Common pitfalls
- Cast token_ids to int32. They arrive as floats from the harness; embedding[token_ids] with float indices is an error in JAX, which requires integer or boolean indexers. Cast with token_ids.astype(jnp.int32).
- embed.embedding instead of embed.embedding.value.T. The first is (V, D), the wrong shape for the projection. Always transpose, always unwrap with .value.
- Two separate matrices. self.embedding = nnx.Param(...); self.head_kernel = nnx.Param(...): you've now got two (V, D) params, doubled the parameter count, and lost the regularization. The whole point is to USE THE SAME matrix.
- Dot product with wrong dims. embed @ embedding is (T, D) @ (V, D), a shape mismatch. Need .T: embed @ embedding.T = (T, D) @ (D, V) = (T, V).
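The cast and the shape pitfall can be reproduced in a few lines of plain JAX. This is a sketch with a toy table; the values and sizes are made up, and the except clause is deliberately broad because the exact exception type is not something the text above specifies.

```python
import jax.numpy as jnp

V, D = 6, 4  # toy sizes; V != D so the mismatch is visible
embedding = jnp.arange(V * D, dtype=jnp.float32).reshape(V, D)

raw_ids = jnp.array([0.0, 3.0, 5.0])    # ids as floats, like the harness sends
token_ids = raw_ids.astype(jnp.int32)   # cast before indexing
embed = embedding[token_ids]            # (T, D)

# Forgetting the transpose: (T, D) @ (V, D) has mismatched inner dims.
try:
    embed @ embedding
    mismatch_raised = False
except (TypeError, ValueError):
    mismatch_raised = True

logits = embed @ embedding.T            # (T, D) @ (D, V) -> (T, V)

print(token_ids.dtype, mismatch_raised, logits.shape)
```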
Problem
Write tied_io_embed(seed, token_ids, vocab_size, d_model):
- Define TiedEmbed(nnx.Module):
  - self.embedding = nnx.Param(...) of shape (vocab_size, d_model), normal-init scaled by 1/sqrt(d_model). Use rngs.params() for the key.
  - __call__(token_ids): embed = self.embedding.value[token_ids], then logits = embed @ self.embedding.value.T.
- Cast vocab_size, d_model to int. Cast token_ids to jnp.int32. Build with nnx.Rngs(int(seed)).
- Return model(ids).reshape(-1).
Output is flattened logits of shape (T * V,).
Inputs:
- seed: int (passed as float).
- token_ids: 1-D (T,) of int token ids (passed as floats).
- vocab_size, d_model: ints (passed as floats).
Output: 1-D flattened (T * vocab_size,).