
NNX Implement Embed

Why this matters

Token embeddings are how every transformer turns discrete vocabulary indices into continuous vectors the rest of the model can do math on. BERT, GPT-2, GPT-3, LLaMA, T5: every one of them starts with this layer. The embedding matrix holds vocab_size × d_model parameters, which is a large share of the total in small models: in GPT-2 small, the token embedding is 50,257 × 768 ≈ 38.6M of roughly 124M parameters. The share shrinks as models grow; in the 1.5B-parameter GPT-2 the token embedding is about 80M parameters (50,257 × 1,600).

Reimplementing it pins down what the layer actually IS: a learned (vocab_size, d_model) matrix, indexed by token IDs to produce embeddings. There’s no nonlinearity, no initialization scale magic in the spec — just a lookup table that gets updated by gradient descent.

API

One trainable parameter: the embedding matrix (vocab_size, d_model), initialized from jax.random.normal. The forward is a single index:

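import jax
from flax import nnx

class MyEmbed(nnx.Module):
    def __init__(self, vocab_size: int, d_model: int, rngs: nnx.Rngs):
        key = rngs.params()   # draw a PRNG key for parameter initialization
        # The only trainable parameter: a (vocab_size, d_model) lookup table.
        self.embedding = nnx.Param(
            jax.random.normal(key, (vocab_size, d_model))
        )

    def __call__(self, token_ids):
        return self.embedding[token_ids]   # gather along axis 0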

self.embedding[token_ids] is advanced integer indexing — for an input of shape (T,) containing integer IDs, the output has shape (T, d_model). Each row is embedding[token_ids[t]].
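To make the shape rule concrete, here is a minimal sketch that uses a plain jnp array in place of the nnx.Param. The table values and ids are made up for illustration, and the batched (B, T) case is an extension the exercise itself does not require.

```python
import jax.numpy as jnp

vocab_size, d_model = 5, 3
# Toy table: row i is [3*i, 3*i + 1, 3*i + 2], so rows are easy to recognise.
table = jnp.arange(vocab_size * d_model, dtype=jnp.float32).reshape(vocab_size, d_model)

ids = jnp.array([4, 0, 4], dtype=jnp.int32)    # shape (T,) with T = 3
out = table[ids]                               # shape (T, d_model)

# The same indexing works on batched ids: output gains a trailing d_model axis.
batch_ids = jnp.array([[1, 2], [3, 0]], dtype=jnp.int32)   # shape (B, T)
batch_out = table[batch_ids]                               # shape (B, T, d_model)
```

Repeated ids simply select the same row twice; nothing is copied into the table.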

Initialization

Many production models use normal(0, 0.02) (BERT/GPT-2) instead of standard normal. We use plain jax.random.normal here because it’s the simplest and lets you focus on the lookup mechanics. In a real model you’d swap to a calibrated init.
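As a rough sketch of what swapping the init looks like, scaling a standard-normal draw by 0.02 gives the normal(0, 0.02) variant. The key, shape, and variable names below are arbitrary choices for illustration, not part of the exercise.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (1000, 64)   # (vocab_size, d_model), chosen only for the demo

standard = jax.random.normal(key, shape)         # standard deviation close to 1.0
scaled = 0.02 * jax.random.normal(key, shape)    # standard deviation close to 0.02

std_standard = float(jnp.std(standard))
std_scaled = float(jnp.std(scaled))
```

Only the scale of the initial values changes; the lookup in __call__ stays exactly the same.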

Casting integer ids

The harness passes inputs as JAX float arrays. Embedding lookups need integers — JAX does NOT auto-cast for indexing. Cast before lookup:

ids = token_ids.astype(jnp.int32)
out = model(ids)

Forgetting this is the #1 mistake. Fortunately the error is clear: JAX raises a TypeError saying the indexer must have integer or boolean type.
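A small sketch of the failure and the fix, using a plain jnp array as a stand-in for the embedding matrix. The exact exception type and wording may vary by JAX version, so the sketch catches both TypeError and IndexError rather than assuming one.

```python
import jax.numpy as jnp

table = jnp.arange(12.0).reshape(4, 3)      # stand-in (vocab_size=4, d_model=3) table
float_ids = jnp.array([2.0, 0.0, 3.0])      # float ids, as the harness passes them

# Indexing with a float array is rejected instead of being silently cast.
try:
    table[float_ids]
    raised = False
except (TypeError, IndexError):
    raised = True

# Casting first makes the lookup valid.
ids = float_ids.astype(jnp.int32)
out = table[ids]                            # shape (3, 3)
```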

Linen contrast

# Linen, for contrast.
import flax.linen as nn

class MyEmbed(nn.Module):
    vocab_size: int
    d_model: int

    @nn.compact
    def __call__(self, ids):
        # The parameter is declared inside the call and stored outside the module.
        E = self.param("embedding", nn.initializers.normal(0.02),
                       (self.vocab_size, self.d_model))
        return E[ids]

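The lookup math is identical; only the initializer differs here (stddev 0.02 versus standard normal). NNX removes the separate init and apply steps: the embedding matrix is just an attribute of the module.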

Common pitfalls

  • Float token ids. Cast to jnp.int32 before lookup. (jnp.int64 is truncated to int32 unless 64-bit mode, jax_enable_x64, is enabled.)
  • Wrong embedding shape. Must be (vocab_size, d_model), NOT (d_model, vocab_size). The first axis is the lookup axis.
  • Out-of-range IDs. embedding[5] when vocab_size=4 does not raise in JAX: out-of-bounds indices are clamped for lookups, so you silently get the last row, and negative IDs wrap around NumPy-style. Real models validate this upstream.
  • Bare matmul instead of indexing. Some implementations turn IDs into one-hot vectors and matmul with the embedding matrix. That’s mathematically equivalent but slower and memory-hostile; use indexing.
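Two of the pitfalls above are easy to check by hand. The sketch below uses a toy table with made-up values; the clamping result reflects JAX's default behavior for index retrieval as I understand it, so treat it as an observation to verify on your own version rather than a guarantee.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 4, 3
table = jnp.arange(vocab_size * d_model, dtype=jnp.float32).reshape(vocab_size, d_model)

# Out-of-range id: no exception is raised; the index is clamped to the last row.
oob = table[jnp.array([5], dtype=jnp.int32)]       # shape (1, d_model)

# One-hot matmul: builds a (T, vocab_size) matrix just to select T rows.
ids = jnp.array([2, 0, 3], dtype=jnp.int32)
one_hot = jax.nn.one_hot(ids, vocab_size)          # shape (T, vocab_size)
via_matmul = one_hot @ table                       # shape (T, d_model)
via_index = table[ids]                             # same values, no (T, vocab_size) temp
```

With a real vocabulary (tens of thousands of entries) the one-hot intermediate dominates memory, which is why plain indexing is preferred.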

Problem

Write embed_forward(seed, token_ids, vocab_size, d_model):

  1. Define MyEmbed(nnx.Module) with one nnx.Param: embedding matrix shape (vocab_size, d_model), initialized from jax.random.normal(rngs.params(), shape).
  2. __call__(self, token_ids): return self.embedding[token_ids].
  3. Cast token_ids to jnp.int32 before calling.
  4. Build with nnx.Rngs(int(seed)), instantiate (vocab_size=int(vocab_size), d_model=int(d_model)), return model(ids).reshape(-1).

Inputs:

  • seed: int (passed as float).
  • token_ids: 1-D JAX array of length T.
  • vocab_size, d_model: ints (passed as floats — cast).

Output: 1-D array of length T * d_model.

Hints

flax nnx embedding reimplementation
