NNX Implement Embed
Why this matters
Token embeddings are how every transformer turns discrete vocabulary
indices into continuous vectors the rest of the model can do math on.
BERT, GPT-2, GPT-3, LLaMA, T5: every one of them starts with this
layer. The table holds vocab_size × d_model parameters, so d_model and
vocab_size together account for a large share of a small model's
parameter count (in GPT-2 small, the 50,257 × 768 token-embedding
matrix is about 38.6M of roughly 124M parameters, close to a third).
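To make the parameter-count claim concrete, here is the arithmetic for GPT-2 small, assuming its published sizes (a 50,257-token vocabulary, d_model = 768, and 1,024 learned positions). The helper name embedding_params is hypothetical, introduced only for this sketch.

```python
def embedding_params(vocab_size: int, d_model: int) -> int:
    """Number of parameters in a (vocab_size, d_model) lookup table."""
    return vocab_size * d_model

# GPT-2 small sizes (assumed from its published configuration).
VOCAB, D_MODEL, N_POSITIONS = 50_257, 768, 1_024

token_table = embedding_params(VOCAB, D_MODEL)           # 38,597,376
position_table = embedding_params(N_POSITIONS, D_MODEL)  # 786,432

print(f"token embeddings:     {token_table:,}")
print(f"position embeddings:  {position_table:,}")
print(f"share of ~124M total: {token_table / 124e6:.0%}")  # 31%
```

The position table is tiny by comparison; almost all of the embedding cost comes from the vocabulary axis.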
Reimplementing it pins down what the layer actually IS: a learned
(vocab_size, d_model) matrix, indexed by token IDs to produce
embeddings. There’s no nonlinearity, no initialization scale magic
in the spec — just a lookup table that gets updated by gradient
descent.
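Because the layer is a plain lookup, its gradient is sparse: only rows that were actually looked up get a nonzero update, and a row looked up twice accumulates twice. A minimal pure-JAX sketch; the toy loss (a plain sum of the gathered rows) and the learning rate are assumptions chosen so the gradient is easy to read.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 5, 3
E = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))
ids = jnp.array([1, 3, 1], dtype=jnp.int32)  # token 1 appears twice

def loss(table):
    # Toy loss: sum every element of the gathered embeddings.
    return table[ids].sum()

g = jax.grad(loss)(E)
# Row i of g equals (number of times i was looked up) * ones(d_model):
# row 1 -> 2.0, row 3 -> 1.0, rows 0, 2, 4 -> 0.0.
print(g)

# One plain SGD step: rows that were never looked up are left untouched.
E_new = E - 0.1 * g
```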
API
One trainable parameter: the embedding matrix of shape (vocab_size,
d_model), initialized from jax.random.normal. The forward pass is a
single indexing operation:
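```python
import jax
from flax import nnx

class MyEmbed(nnx.Module):
    def __init__(self, vocab_size: int, d_model: int, rngs: nnx.Rngs):
        key = rngs.params()  # draw a key from the params stream
        # Standard-normal init of the (vocab_size, d_model) lookup table.
        self.embedding = nnx.Param(
            jax.random.normal(key, (vocab_size, d_model))
        )

    def __call__(self, token_ids):
        return self.embedding[token_ids]  # gather along axis 0
```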
self.embedding[token_ids] is advanced integer indexing — for an
input of shape (T,) containing integer IDs, the output has shape
(T, d_model). Each row is embedding[token_ids[t]].
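A quick way to check the shape rule is to run the gather on a small table with plain jax.numpy. The sketch also assumes a batched (B, T) input, which the text does not cover, to show that advanced indexing simply appends d_model to whatever shape the IDs have; jnp.take along axis 0 is an equivalent spelling of the same gather. Table sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 10, 5
E = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))

ids = jnp.array([7, 2, 2, 9], dtype=jnp.int32)  # shape (T,) = (4,)
out = E[ids]                                    # shape (T, d_model)
print(out.shape)  # (4, 5)

# Row t of the output is exactly row ids[t] of the table.
assert all(jnp.array_equal(out[t], E[ids[t]]) for t in range(ids.shape[0]))

# The same gather written with jnp.take along the lookup axis.
assert jnp.array_equal(out, jnp.take(E, ids, axis=0))

# A batch of sequences, shape (B, T), yields (B, T, d_model).
batch = jnp.array([[0, 1, 2], [3, 4, 5]], dtype=jnp.int32)
print(E[batch].shape)  # (2, 3, 5)
```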
Initialization
Many production models use a 0.02 standard deviation instead of
standard normal (GPT-2 uses normal(0, 0.02); BERT uses a truncated
normal with the same 0.02 scale). We use plain jax.random.normal here
because it's the simplest and lets you focus on the lookup mechanics.
In a real model you'd swap in a scaled init such as normal(0, 0.02).
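Switching to the smaller scale is a one-line change. A sketch using jax.nn.initializers.normal, which returns an initializer function of (key, shape, dtype); the table sizes are arbitrary demo choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (1000, 64)  # (vocab_size, d_model), arbitrary demo sizes

# Plain standard normal, as used in this problem.
plain = jax.random.normal(key, shape)

# normal(0, 0.02) via JAX's initializer factory.
init_fn = jax.nn.initializers.normal(stddev=0.02)
scaled = init_fn(key, shape, jnp.float32)

# The same thing by hand: scale a standard-normal draw.
by_hand = 0.02 * jax.random.normal(key, shape)

print(float(plain.std()), float(scaled.std()))  # roughly 1.0 and 0.02
```

Only the scale changes; the lookup in __call__ is identical either way.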
Casting to integer IDs
The harness passes inputs as JAX float arrays. Embedding lookups need integers, and JAX does NOT auto-cast float arrays used as indices. Cast before lookup:
```python
ids = token_ids.astype(jnp.int32)  # float IDs -> int32 row indices
out = model(ids)
```
Forgetting this is the #1 mistake. Fortunately the error is clear: JAX raises a TypeError stating that the indexer must have integer or boolean type.
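The failure and the fix side by side, on a toy table (sizes arbitrary). Whole-number floats such as 2.0 convert exactly with astype; note that astype truncates toward zero, so it would not round a value like 2.9.

```python
import jax.numpy as jnp

E = jnp.arange(12.0).reshape(4, 3)      # toy (vocab_size=4, d_model=3) table
float_ids = jnp.array([0.0, 2.0, 3.0])  # what the harness hands you

try:
    E[float_ids]
    raised = False
except (TypeError, IndexError) as err:  # JAX raises TypeError for float indexers
    raised = True
    print("lookup failed:", err)

ids = float_ids.astype(jnp.int32)       # cast first
out = E[ids]
print(out.shape)  # (3, 3)
```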
Linen contrast
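```python
import flax.linen as nn

# Linen, for contrast.
class MyEmbed(nn.Module):
    vocab_size: int
    d_model: int

    @nn.compact
    def __call__(self, ids):
        # self.param creates the table during model.init and reads it during model.apply.
        E = self.param("embedding", nn.initializers.normal(0.02),
                       (self.vocab_size, self.d_model))
        return E[ids]
```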
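The lookup math is identical; only the init scale differs (this Linen version uses normal(0, 0.02)). The real difference is the API: a Linen module holds no state, so you call model.init(key, ids) to create the parameters and model.apply(variables, ids) to run the forward pass. nnx removes that init+apply split; the embedding matrix is just an attribute of the module.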
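To see what the init/apply split means without pulling in Linen, here is a hypothetical pure-JAX mirror of that calling convention; init_embed and apply_embed are illustrative names, not Flax functions. Parameters live in a separate dict that you thread through every call, which is the bookkeeping nnx folds into the module object.

```python
import jax
import jax.numpy as jnp

def init_embed(key, vocab_size, d_model):
    # Hypothetical Linen-style init: returns a params pytree, no object state.
    table = 0.02 * jax.random.normal(key, (vocab_size, d_model))
    return {"params": {"embedding": table}}

def apply_embed(variables, ids):
    # Hypothetical Linen-style apply: parameters are passed in explicitly.
    return variables["params"]["embedding"][ids]

variables = init_embed(jax.random.PRNGKey(0), vocab_size=8, d_model=4)
ids = jnp.array([1, 5, 1], dtype=jnp.int32)
out = apply_embed(variables, ids)
print(out.shape)  # (3, 4)

# Since params are an explicit argument, gradients come back as the same dict.
grads = jax.grad(lambda v: apply_embed(v, ids).sum())(variables)
print(grads["params"]["embedding"].shape)  # (8, 4)
```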
Common pitfalls
- Float token ids. Cast to jnp.int32 before lookup (int64 also works, but unless x64 mode is enabled JAX keeps it as int32).
- Wrong embedding shape. Must be (vocab_size, d_model), NOT (d_model, vocab_size). The first axis is the lookup axis.
- Out-of-range IDs. An ID of 5 when vocab_size=4 does not raise in JAX: out-of-bounds indices in an ID array are clamped, so you silently get the last row (embedding[3]) instead of an error. Real models validate IDs upstream.
- Bare matmul instead of indexing. Some implementations turn IDs into one-hot vectors and matmul with the embedding matrix. That's mathematically equivalent but slower and memory-hostile, since it materializes a (T, vocab_size) one-hot array; use indexing.
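Two of these pitfalls are easy to check in code. The sketch below, on an arbitrary toy table, confirms that the one-hot matmul matches indexing while allocating a (T, vocab_size) intermediate, and shows a simple upstream range check; check_ids is a hypothetical helper, not part of JAX or Flax.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 6, 3
E = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))
ids = jnp.array([4, 0, 4, 5], dtype=jnp.int32)

# Indexing: a gather that reads T rows.
via_index = E[ids]

# One-hot matmul: builds a (T, vocab_size) matrix first, then multiplies.
one_hot = jax.nn.one_hot(ids, vocab_size)  # shape (4, 6)
via_matmul = one_hot @ E

print(jnp.allclose(via_index, via_matmul))  # True

def check_ids(ids, vocab_size):
    # Hypothetical upstream guard. Runs eagerly: bool() on a traced
    # value would fail inside jax.jit.
    if not bool(jnp.all((ids >= 0) & (ids < vocab_size))):
        raise ValueError(f"token id out of range [0, {vocab_size})")
    return ids

check_ids(ids, vocab_size)  # passes
try:
    check_ids(jnp.array([5, 6]), vocab_size)  # 6 is out of range
    rejected = False
except ValueError as err:
    rejected = True
    print(err)
```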
Problem
Write embed_forward(seed, token_ids, vocab_size, d_model):
- Define MyEmbed(nnx.Module) with one nnx.Param: the embedding matrix of shape (vocab_size, d_model), initialized from jax.random.normal(rngs.params(), shape).
- __call__(self, token_ids): return self.embedding[token_ids].
- Cast token_ids to jnp.int32 before calling.
- Build with nnx.Rngs(int(seed)), instantiate with (vocab_size=int(vocab_size), d_model=int(d_model)), and return model(ids).reshape(-1).
Inputs:
- seed: int (passed as float).
- token_ids: 1-D JAX array of length T.
- vocab_size, d_model: ints (passed as floats; cast them).
Output: 1-D array of length T * d_model.