Token Embedding with Flax
Why this matters
Every Transformer language model starts the same way: a stream of integer
token IDs in, a stream of dense vectors out. That step is the token
embedding: a learned lookup table of shape (vocab_size, d_model),
indexed row-by-row by the input IDs.
Conceptually it's W[token_ids], where W is a (V, D) matrix. Each
row is the learned vector for one vocabulary item. In any given batch, only
the rows that are actually looked up receive a nonzero gradient, which is why
even huge embedding tables (V ≈ 50k, D ≈ 4096) train efficiently.
Equivalently: an embedding lookup is matrix multiplication by a one-hot
vector. one_hot(id) @ W = W[id]. The lookup is just the fast path.
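As a quick check of both claims above, here is a minimal sketch in plain JAX. The sizes, the random table W, and the toy_loss function are illustrative assumptions, not part of the original material. It compares the direct lookup with the one-hot matrix product, then inspects which rows of the table receive a gradient.

```python
import jax
import jax.numpy as jnp

V, D = 6, 3  # toy sizes chosen for illustration
W = jax.random.normal(jax.random.PRNGKey(0), (V, D))
token_ids = jnp.array([1, 4, 1])  # row 1 is used twice, row 4 once

# Fast path: gather the rows directly.
via_lookup = W[token_ids]                      # (T, D)
# Slow path: one-hot rows times the table.
via_matmul = jax.nn.one_hot(token_ids, V) @ W  # (T, V) @ (V, D) -> (T, D)
print(jnp.allclose(via_lookup, via_matmul, atol=1e-6))

def toy_loss(table):
    # Hypothetical scalar loss: sum of all looked-up vectors.
    return table[token_ids].sum()

grad = jax.grad(toy_loss)(W)
rows_with_grad = jnp.any(grad != 0, axis=1)
print(rows_with_grad)  # True only at rows 1 and 4
```

Row 1 appears twice in token_ids, so its gradient is accumulated twice; rows that never appear keep a zero gradient.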
The Flax API
embed = nn.Embed(num_embeddings=V, features=D)
params = embed.init(rng, token_ids) # rng: a jax.random.PRNGKey
out = embed.apply(params, token_ids) # (T,) int → (T, D) float
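- num_embeddings=V: vocabulary size; the table has V rows.
- features=D: embedding dimension; each row is a D-vector.
- The default initialiser in recent Flax Linen releases is a fan-in variance-scaling normal, variance_scaling(1.0, 'fan_in', 'normal', out_axis=0), not a unit normal. Production models often override it through the embedding_init argument (e.g. normal(stddev=0.02) à la GPT/BERT).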
Input token_ids MUST be an integer dtype. Float IDs make the lookup raise
an error (a ValueError in recent Flax Linen releases). Cast with
.astype(jnp.int32) if needed.
Worked example
import jax, jax.numpy as jnp
import flax.linen as nn
embed = nn.Embed(num_embeddings=4, features=3)
rng = jax.random.PRNGKey(0)
ids = jnp.array([0, 2, 1])
params = embed.init(rng, ids) # {'params': {'embedding': (4, 3)}}
out = embed.apply(params, ids) # (3, 3): rows 0, 2, 1 of the table
Note the param tree: params['params']['embedding'] is the (V, D)
matrix. After training you'd often save/load just this tensor.
Common pitfalls
- Float token_ids: cast to jnp.int32 (the test harness ships ints as floats, so the cast is mandatory).
- IDs out of range: embed(jnp.array([V])) does not raise an error. JAX does not bounds-check array lookups; depending on the JAX/Flax version, the out-of-range index is clamped to a valid row or the result is filled with NaN. Always ensure 0 <= id < V.
- Confusing features with num_embeddings: features is D (the per-token vector size), num_embeddings is V (how many distinct tokens exist). Swapping them builds a table of the wrong shape, (D, V) instead of (V, D).
- Forgetting .reshape(-1) when the harness expects flat output.
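The first two pitfalls above can be guarded against with a small host-side check before the lookup. The helper below is a hypothetical sketch (prepare_token_ids is not a Flax or JAX function); it assumes the IDs are concrete values available outside jit, so the check can run eagerly in NumPy.

```python
import numpy as np
import jax.numpy as jnp

def prepare_token_ids(raw_ids, vocab_size):
    """Hypothetical helper: cast IDs to int32 and check 0 <= id < vocab_size."""
    ids = np.asarray(raw_ids)
    # Reject values such as 1.5 that are not whole numbers.
    if not np.all(ids == np.round(ids)):
        raise ValueError("token IDs must be whole numbers")
    ids = ids.astype(np.int32)
    # JAX would not raise on an out-of-range ID, so check explicitly.
    if ids.size and (ids.min() < 0 or ids.max() >= vocab_size):
        raise ValueError(f"token IDs must lie in [0, {vocab_size})")
    return jnp.asarray(ids)

ids = prepare_token_ids([0.0, 2.0, 1.0], vocab_size=4)
print(ids, ids.dtype)
```

Running the check on the host keeps it out of any jitted computation, where a data-dependent Python `if` would not work.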
Problem
Implement token_embed_forward(seed, token_ids, vocab_size, d_model):
- Cast vocab_size, d_model to int.
- Cast token_ids to jnp.int32 (it arrives as float).
- Build nn.Embed(num_embeddings=V, features=D).
- Init with jax.random.PRNGKey(seed) and apply on token_ids.
- Return the output flattened with .reshape(-1).
Inputs:
- seed: int.
- token_ids: 1-D float array (cast inside).
- vocab_size: int V.
- d_model: int D.
Output: 1-D array of length T * D (the flattened (T, D) lookup).