NNX Tied I/O Embed

Why this matters

Weight tying is the trick of using the SAME parameter matrix for two different roles in a model. The classic case in language models: tie the input embedding (V, D) to the output projection head (D, V) by using its transpose. One matrix; two uses.

Why? Two reasons:

  1. Parameter savings. In GPT-2 small (V=50257, D=768), the output head ALONE is about 38.6M params, and tying it to the embedding removes all of them from the total. In bigger models (V=128k, D=4096) the saving is about half a billion parameters.
  2. Implicit regularization. The model is forced to use one representation both for “what does this token mean as input?” and for “how do I score this token at the output?”. Empirically, this improves perplexity at small to medium scales.
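A quick check of the savings figures, counting only the V × D head matrix (no bias) and reading "128k" as 128,000 for the second configuration (an assumption; some vocabularies are slightly larger):

```python
def head_params(vocab_size: int, d_model: int) -> int:
    """Parameters in an untied (D, V) output head, ignoring any bias."""
    return vocab_size * d_model

gpt2_small = head_params(50257, 768)       # GPT-2 small vocabulary and width
big_model = head_params(128_000, 4096)     # "128k" read as 128,000 (assumption)

print(f"GPT-2 small head: {gpt2_small:,}")      # GPT-2 small head: 38,597,376
print(f"V=128k, D=4096 head: {big_model:,}")    # V=128k, D=4096 head: 524,288,000
```

Tying removes exactly one of these matrices; the shared embedding table is still counted once.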

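Used by: GPT-2, GPT-3, T5, BART, BERT (kind of: its masked-LM output layer reuses the input embedding matrix), and many smaller decoder-only LLMs. Untied heads are common at larger scales (Llama 2, for example, keeps a separate output matrix), where the regularization is less needed and the savings are a smaller share of the total parameter count.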

The mechanic

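import jax
import jax.numpy as jnp

V, D = 10, 4                                                   # toy sizes
embedding = jax.random.normal(jax.random.PRNGKey(0), (V, D))   # (V, D) matrix (the parameter)
token_ids = jnp.array([3, 1, 7], dtype=jnp.int32)              # (T,) ints

# Input pathway: lookup (treats embedding's rows as token vectors).
embed = embedding[token_ids]            # (T, D)

# Output pathway: project T x D back to T x V.
logits = embed @ embedding.T            # (T, D) @ (D, V) -> (T, V)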

Multiplying by the transpose answers “how similar is embed[t] to each row of the embedding table?”. Each logit is an inner product between a position’s hidden state and a candidate token’s embedding vector: a simple scoring function that stays aligned with the input representation because both roles train the same matrix.
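To make the inner-product reading concrete, here is a small pure-JAX sketch (toy sizes chosen only for illustration) that computes each logit as an explicit dot product and compares it with the single matmul:

```python
import jax
import jax.numpy as jnp

V, D = 6, 4                                    # toy vocab size and width (illustrative)
embedding = jax.random.normal(jax.random.PRNGKey(0), (V, D))
token_ids = jnp.array([2, 0, 5], dtype=jnp.int32)

embed = embedding[token_ids]                   # (T, D) input pathway
logits = embed @ embedding.T                   # (T, V) output pathway

# Same thing, one inner product at a time: logits[t, v] = <embed[t], embedding[v]>.
manual = jnp.array([[jnp.dot(embed[t], embedding[v]) for v in range(V)]
                    for t in range(token_ids.shape[0])])

print(logits.shape)                            # (3, 6)
print(jnp.allclose(logits, manual, atol=1e-4))
```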

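For this problem we skip everything in between (no transformer blocks); the model is just the embedding plus the tied output. So logits[t, v] is the dot product between the embedding of input token t and the embedding of vocabulary token v: each row scores one input token against every token in the vocabulary, not just against the other input tokens.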
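Another way to see it: with nothing in between, each output row is a row of the V × V Gram matrix embedding @ embedding.T, picked out by the token id. A pure-JAX sketch with toy sizes (an illustration, not part of the required solution):

```python
import jax
import jax.numpy as jnp

V, D = 8, 16                                        # toy sizes (illustrative)
embedding = jax.random.normal(jax.random.PRNGKey(1), (V, D)) / jnp.sqrt(D)
token_ids = jnp.array([3, 3, 7], dtype=jnp.int32)

logits = embedding[token_ids] @ embedding.T         # (T, V), the tied model's output
gram = embedding @ embedding.T                      # (V, V) all pairwise token similarities

# Row t of the logits is row token_ids[t] of the Gram matrix...
rows_match = jnp.allclose(logits, gram[token_ids], atol=1e-5)
# ...so repeated token ids give identical logit rows...
repeat_match = jnp.allclose(logits[0], logits[1], atol=1e-6)
# ...and each token's score for itself is its squared embedding norm.
self_scores = logits[jnp.arange(3), token_ids]
norms_sq = jnp.sum(embedding[token_ids] ** 2, axis=-1)
```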

API: nnx.Param with custom init

nnx.Embed is the off-the-shelf token-embedding module, but the point here is to be EXPLICIT about the matrix and reuse it. So we define our own nnx.Param:

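import jax
import jax.numpy as jnp
from flax import nnx

class TiedEmbed(nnx.Module):
    def __init__(self, vocab_size, d_model, rngs):
        key = rngs.params()                          # key from the "params" stream
        self.embedding = nnx.Param(                  # the ONE shared (V, D) matrix
            jax.random.normal(key, (vocab_size, d_model))
            * (1.0 / jnp.sqrt(d_model))
        )

    def __call__(self, token_ids):
        embed = self.embedding.value[token_ids]      # input pathway: row lookup, (T, D)
        logits = embed @ self.embedding.value.T      # output pathway: tied head, (T, V)
        return logits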

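The 1/sqrt(D) scaling is a LeCun-style fan-in init (each entry has variance 1/D), not Glorot/Xavier, which would use variance 2/(fan_in + fan_out); it keeps each embedding row’s squared norm close to 1 on average. rngs.params() pulls a key from the params stream of the Rngs container, the canonical way to get a key inside an nnx Module’s __init__. When the container is built from a single seed (nnx.Rngs(seed)), the call falls back to its default stream.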
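A quick empirical check of what that scaling gives, in pure JAX with arbitrary sizes: the entries have a standard deviation near 1/sqrt(D), so each row’s squared norm averages about 1 (D entries of variance 1/D each).

```python
import jax
import jax.numpy as jnp

V, D = 4096, 256                                    # illustrative sizes
key = jax.random.PRNGKey(0)
embedding = jax.random.normal(key, (V, D)) * (1.0 / jnp.sqrt(D))

entry_std = float(jnp.std(embedding))               # expected near 1/sqrt(256) = 0.0625
mean_sq_norm = float(jnp.mean(jnp.sum(embedding ** 2, axis=-1)))  # expected near 1.0

print(f"entry std {entry_std:.4f}, mean squared row norm {mean_sq_norm:.3f}")
```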

Why use .value.T, not .T

nnx.Param wraps the array. Arithmetic operators on the wrapper (+, -, *, @) are forwarded to the underlying array, but attribute access such as .T is not guaranteed to reach it: whether the wrapper forwards the lookup depends on the Flax version, and where it does not, you get an error.

Pattern: always go through .value for anything that touches the array itself (.shape, .T, .dtype, and indexing). Arithmetic on the wrapper (+, -, @) is safe, but unwrapping first costs nothing and is unambiguous.

Compare with mini-GPT

Pos 43 (nnx-mini-gpt) used the same trick at the END of a full transformer:

x = self.ln_f(x)
return x @ self.embed.embedding.value.T

Same operation, different surrounding context. Once you recognize embed.embedding.value.T, you’ll spot it in many language-model codebases that tie their weights.

Common pitfalls

  • Cast token_ids to int32. They arrive as floats from the harness, and JAX rejects float index arrays: embedding[token_ids] raises a TypeError. Cast with token_ids.astype(jnp.int32).
  • Using embed.embedding instead of embed.embedding.value.T. The first is (V, D), the wrong shape for the projection. Always transpose, and always unwrap with .value.
  • Two separate matrices. self.embedding = nnx.Param(...); self.head_kernel = nnx.Param(...) gives you two V × D matrices, doubles the embedding parameter count, and loses the regularization. The whole point is to USE THE SAME matrix.
  • Dot product with wrong dims. embed @ embedding is (T, D) @ (V, D) — shape mismatch. Need .T: embed @ embedding.T = (T, D) @ (D, V) = (T, V).
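What sharing means during training: the single matrix receives the sum of the gradients from both roles. A pure-JAX sketch (toy sizes, a mean cross-entropy loss, and made-up next-token targets, all chosen only for illustration) compares the tied gradient with the two gradients an untied pair would get at the same values:

```python
import jax
import jax.numpy as jnp

V, D = 7, 5
E = jax.random.normal(jax.random.PRNGKey(0), (V, D)) / jnp.sqrt(D)
ids = jnp.array([1, 3, 6], dtype=jnp.int32)       # input tokens
targets = jnp.array([3, 6, 2], dtype=jnp.int32)   # made-up next tokens

def xent(logits, targets):
    # Mean cross-entropy of integer targets under softmax(logits).
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(logp[jnp.arange(targets.shape[0]), targets])

def untied_loss(E_in, E_out):
    return xent(E_in[ids] @ E_out.T, targets)     # separate input table and output head

def tied_loss(E):
    return untied_loss(E, E)                      # one matrix in both roles

g_tied = jax.grad(tied_loss)(E)
g_in, g_out = jax.grad(untied_loss, argnums=(0, 1))(E, E)
# Chain rule: the tied gradient equals the input-role plus output-role gradients.
print(jnp.allclose(g_tied, g_in + g_out, atol=1e-6))
```

Note that the input role only touches the looked-up rows (g_in is zero for rows 0, 2, 4, 5), while the output role sends gradient to every row; with two separate matrices those updates would land on different parameters.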

Problem

Write tied_io_embed(seed, token_ids, vocab_size, d_model):

  1. Define TiedEmbed(nnx.Module):
    • self.embedding = nnx.Param(...) of shape (vocab_size, d_model), normal-init scaled by 1/sqrt(d_model). Use rngs.params() for the key.
    • __call__(token_ids): embed = self.embedding.value[token_ids], then logits = embed @ self.embedding.value.T.
  2. Cast vocab_size, d_model to int. Cast token_ids to jnp.int32. Build with nnx.Rngs(int(seed)).
  3. Return model(ids).reshape(-1), where ids is the int32-cast token_ids.

Output is flattened logits of shape (T * V,).

Inputs:

  • seed: int (passed as float).
  • token_ids: 1-D (T,) of int token ids (passed as floats).
  • vocab_size, d_model: ints (passed as floats).

Output: 1-D flattened (T * vocab_size,).
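Because every input arrives as a float, the casts in step 2 matter. A small pure-JAX sketch with made-up inputs shows the float-index failure and the cast that fixes it:

```python
import jax
import jax.numpy as jnp

table = jnp.arange(20.0).reshape(5, 4)     # stand-in (V=5, D=4) embedding
raw_ids = jnp.array([1.0, 4.0, 0.0])       # token ids as the harness sends them
raw_vocab, raw_d = 5.0, 4.0                # sizes also arrive as floats

try:
    table[raw_ids]                         # float indices are rejected by JAX
    float_index_failed = False
except (TypeError, IndexError):
    float_index_failed = True

ids = raw_ids.astype(jnp.int32)            # cast once, then index
rows = table[ids]                          # (T, D) = (3, 4)
shape = (int(raw_vocab), int(raw_d))       # Python ints, usable as an array shape
```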

Hints

flax nnx embedding weight-tying language-model
