medium primitives

Token Embedding with Flax

Why this matters

Every Transformer language model starts the same way: a stream of integer token IDs in, a stream of dense vectors out. That step is the token embedding: a learned lookup table of shape (vocab_size, d_model), indexed row-by-row by the input IDs.

Conceptually it’s W[token_ids], where W is a (V, D) matrix. Each row is the learned vector for one vocabulary item. The gradient is non-zero only for the rows that appear in a batch, which is why even huge embedding tables (V ≈ 50k, D ≈ 4096) train efficiently.

Equivalently: an embedding lookup is matrix multiplication by a one-hot vector. one_hot(id) @ W = W[id]. The lookup is just the fast path.
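
A minimal sketch of that equivalence in plain JAX, with no Flax module involved. The sizes, the random table W, and the IDs are illustrative assumptions, not values from this problem; the matmul is run at highest precision so the comparison also holds on accelerators that default to reduced-precision matmuls.

```python
import jax
import jax.numpy as jnp

V, D = 5, 3  # illustrative sizes, chosen only for this sketch
W = jax.random.normal(jax.random.PRNGKey(0), (V, D))  # stand-in embedding table
ids = jnp.array([0, 2, 4, 2])

# Fast path: index rows of the table directly.
via_lookup = W[ids]  # (T, D)

# Slow path: build (T, V) one-hot rows and multiply by the (V, D) table.
one_hot = jax.nn.one_hot(ids, V)
via_matmul = jnp.matmul(one_hot, W, precision=jax.lax.Precision.HIGHEST)  # (T, D)

same = bool(jnp.allclose(via_lookup, via_matmul))
print(same)
```

The one-hot route materialises a (T, V) matrix, which is why the indexed lookup is preferred in practice.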

The Flax API

embed = nn.Embed(num_embeddings=V, features=D)
out = embed(token_ids)        # (T,) int → (T, D) float
  • num_embeddings=V: vocabulary size; the table has V rows.
  • features=D: embedding dimension; each row is a D-vector.
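  • The default initialiser (the embedding_init argument) is variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) in current Flax Linen releases, not a unit normal. Production models often override it (e.g. normal(stddev=0.02) à la GPT/BERT).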

Input token_ids MUST be an integer dtype. nn.Embed raises an error on float IDs (a ValueError in current Flax releases) rather than rounding them. Cast with .astype(jnp.int32) if needed.

Worked example

import jax, jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=4, features=3)
rng = jax.random.PRNGKey(0)
ids = jnp.array([0, 2, 1])

params = embed.init(rng, ids)   # {'params': {'embedding': (4, 3)}}
out = embed.apply(params, ids)  # (3, 3): rows 0, 2, 1 of the table

Note the param tree: params['params']['embedding'] is the (V, D) matrix. After training you’d often save/load just this tensor.

Common pitfalls

  • Float token_ids: cast to jnp.int32 (the test harness ships ints as floats, so the cast is mandatory).
  • IDs out of range: embed(jnp.array([V])) does not raise. JAX performs no bounds check on gathers, so depending on the Flax/JAX version the lookup silently clamps to a valid row or returns a NaN-filled row. Always ensure 0 <= id < V.
  • Confusing features with num_embeddings: features is D (the per-token vector size), num_embeddings is V (how many distinct tokens exist). Mixing them up is a real bug.
  • Forgetting .reshape(-1) when the harness expects flat output.
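
One way to guard against the range pitfall is to validate IDs eagerly before the lookup. The helper below, check_token_ids, is hypothetical (it is not part of Flax or of this problem's harness), and because it converts to a Python bool it is only suitable outside jit.

```python
import jax.numpy as jnp

def check_token_ids(token_ids, vocab_size):
    """Hypothetical helper: return int32 IDs, raising if any fall outside [0, vocab_size)."""
    ids = jnp.asarray(token_ids).astype(jnp.int32)
    # Eager check: pulls a concrete bool back to Python, so it will not trace under jit.
    in_range = bool(jnp.all((ids >= 0) & (ids < vocab_size)))
    if not in_range:
        raise ValueError(f"token IDs must satisfy 0 <= id < {vocab_size}")
    return ids

ok = check_token_ids(jnp.array([0.0, 2.0, 1.0]), vocab_size=4)
print(ok.dtype, ok.tolist())
```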

Problem

Implement token_embed_forward(seed, token_ids, vocab_size, d_model):

  1. Cast vocab_size, d_model to int.
  2. Cast token_ids to jnp.int32 (it arrives as float).
  3. Build nn.Embed(num_embeddings=V, features=D).
  4. Init with jax.random.PRNGKey(seed) and apply on token_ids.
  5. Return the output flattened with .reshape(-1).

Inputs:

  • seed: int.
  • token_ids: 1-D float array (cast inside).
  • vocab_size: int V.
  • d_model: int D.

Output: 1-D array of length T * D (the flattened (T, D) lookup).

Hints

flax embedding transformers
