
NNX Mini-GPT

Why this matters

GPT, in its simplest form, is a token embedding plus a learned position embedding, then a stack of identical pre-LN causal Transformer blocks (causal self-attention + FFN), then a final LayerNorm and a tied output head that maps back to vocab logits. Later decoder-only models (LLaMA, Mistral, GPT-NeoX) are variations on the same skeleton: different norm flavors (RMSNorm), different positional schemes (RoPE, ALiBi), different FFN gates (SwiGLU). This problem builds the skeleton.

The architecture, top to bottom

token_ids (T,) int32
  |
  v
nnx.Embed(vocab_size, d_model)       -> (T, d_model)
  + learned pos_embed[:T]            -> (T, d_model)
  |
  v
[GPTBlock] x num_layers              -> (T, d_model)
  |   each block:
  |     x = x + attn(ln1(x))           # causal self-attn
  |     x = x + ff2(relu(ff1(ln2(x)))) # FFN
  v
nnx.LayerNorm(d_model)               -> (T, d_model)
  |
  v
@ embed.embedding.value.T            -> (T, vocab_size)   logits

Four attributes hold parameters: embed, pos_embed, blocks, and ln_f. The output head is implicit: it is just the transpose of embed.embedding, with no separate parameters.

Embedding two ways: token + position

nnx.Embed(vocab_size, d_model, rngs=rngs) is a built-in lookup table — call it on int token ids to get the corresponding rows.

Position embeddings are simpler: a single learned matrix of shape (max_T, d_model), initialized to zeros (or small normal noise; zeros is fine for tiny tests). Slice pos_embed.value[:T] and add it to the token embeddings. Storing it as nnx.Param(jnp.zeros((max_T, d_model))) makes it trainable.

Stacking blocks: nnx.List

Plain Python lists in nnx 0.12.6 are treated as STATIC by default, which breaks a list of modules: their parameters are not picked up as module state. Two equivalent fixes:

self.blocks = nnx.List([GPTBlock(...) for _ in range(num_layers)])
# or
self.blocks = nnx.data([GPTBlock(...) for _ in range(num_layers)])  # mark the plain list as data

nnx.List is the cleanest option: it wraps the list as a data pytree, so split/merge see all the parameters inside. Iterate with a normal Python for block in self.blocks: loop.

Tied output head

The classical GPT trick: the output projection’s weights ARE the embedding matrix, transposed. So logits = x @ embed.embedding.value.T has shape (T, vocab_size), and the only “head” parameter is shared with the input embedding. This saves vocab_size * d_model parameters and is a small but real regularizer.

Read the embedding matrix off the module:

embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
embed.embedding         # the nnx.Param wrapper
embed.embedding.value   # the underlying (vocab_size, d_model) array

Then logits = x @ embed.embedding.value.T. Note the explicit .value unwrap: the nnx.Param wrapper forwards arithmetic such as @ (matmul) to the wrapped array, but unwrapping with .value before taking .T makes it unambiguous that you are transposing the raw (vocab_size, d_model) array.

Casting token_ids

The harness passes token_ids as a float array (the only numeric payload it knows). Cast inside the function: ids = token_ids.astype(jnp.int32). Without this, nnx.Embed raises a ValueError, because the lookup requires integer indices.

Worked sketch

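import jax
import jax.numpy as jnp
from flax import nnx

class GPTBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)
        # CausalMHA is the causal self-attention module from step 1 of the Problem (signature assumed).
        self.attn = CausalMHA(d_model, num_heads, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)

    def __call__(self, x):
        x = x + self.attn(self.ln1(x))                        # pre-LN attention + residual
        x = x + self.ff2(jax.nn.relu(self.ff1(self.ln2(x))))  # pre-LN FFN + residual
        return x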

class MiniGPT(nnx.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_T, rngs):
        self.embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
        self.pos_embed = nnx.Param(jnp.zeros((max_T, d_model)))
        self.blocks = nnx.List([
            GPTBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.ln_f = nnx.LayerNorm(d_model, rngs=rngs)

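    def __call__(self, token_ids):
        # token_ids: (T,) int32; cast float inputs first (see "Casting token_ids").
        T = token_ids.shape[0]
        x = self.embed(token_ids) + self.pos_embed.value[:T]  # (T, d_model)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return x @ self.embed.embedding.value.T               # tied head: (T, vocab_size)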

Common pitfalls

  • Plain self.blocks = [...]. Without nnx.List, nnx treats it as a static (non-data) attribute, so the blocks' parameters are missing from the param tree. Wrap the list in nnx.List(...).
  • Forgetting to cast token_ids to int. nnx.Embed needs integer indices; float inputs raise a ValueError.
  • Slicing position embeddings beyond max_T. pos_embed[:T] requires max_T >= T. JAX clamps the slice, so a longer sequence gets only max_T rows and the addition then fails with a shape error. Pass max_T = T for tests, or use a generous max length in production.
  • Not transposing for the tied head. x @ embed.embedding.value is (T, d_model) @ (vocab_size, d_model) — shape mismatch. Need .T.
  • Skipping the final LayerNorm. GPT applies ln_f BEFORE the output projection. Without it, the hidden states reaching the tied head are unnormalized, which makes training harder; at inference, dropping it from a model trained with it changes the output distribution.
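The two shape pitfalls can be reproduced with plain jax.numpy, using stand-in arrays for pos_embed and the embedding table (toy sizes, chosen for illustration). JAX clamps an out-of-range slice instead of raising, so the max_T mistake only surfaces at the addition; the missing transpose fails at the matmul.

```python
import jax.numpy as jnp

max_T, d_model, vocab_size = 4, 3, 7
pos_embed = jnp.zeros((max_T, d_model))
E = jnp.ones((vocab_size, d_model))   # stand-in for embed.embedding.value

# Slicing past max_T is clamped, not an error...
T = 6
print(pos_embed[:T].shape)            # (4, 3), not (6, 3)

# ...so the failure only shows up when adding to the (T, d_model) embeddings.
x = jnp.ones((T, d_model))
try:
    x + pos_embed[:T]
    add_failed = False
except (TypeError, ValueError):
    add_failed = True
print(add_failed)                     # True

# The tied head needs the transpose: (6, 3) @ (7, 3) has mismatched inner dims.
try:
    x @ E
    matmul_failed = False
except (TypeError, ValueError):
    matmul_failed = True
print(matmul_failed, (x @ E.T).shape)  # True (6, 7)
```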

Problem

Write mini_gpt_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):

  1. Inner CausalMHA(nnx.Module) (causal self-attention, like pos 33).
  2. GPTBlock(nnx.Module) with two LayerNorms, the causal attn, and a two-layer FFN (two nnx.Linear layers) with ReLU. Pre-LN sublayer ordering.
  3. MiniGPT(nnx.Module) with embed, pos_embed (zeros init, shape (max_T, d_model)), blocks (nnx.List of GPTBlock), ln_f. Forward: embed + pos, run blocks, final LN, tied head.
  4. Cast seed and all hyperparameters to int. Cast token_ids to int32.
  5. Return logits.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • token_ids: 1-D (T,) of token indices (passed as floats).
  • vocab_size, d_model, num_heads, d_ff, num_layers: ints (passed as floats).

Output: 1-D flattened logits (T * vocab_size,).

