NNX Mini-GPT
Why this matters
GPT, in its simplest form, is: token embedding + learned position embedding, then a stack of identical pre-LN causal Transformer blocks (causal self-attention + FFN), then a final LayerNorm and a tied output head that maps back to vocab logits. Everything else (LLaMA, Mistral, GPT-NeoX) is variations on the same skeleton — different norm flavors (RMSNorm), different positional schemes (RoPE, ALiBi), different FFN gates (SwiGLU). This problem builds the skeleton.
The architecture, top to bottom
token_ids (T,) int32
    |
    v
nnx.Embed(vocab_size, d_model)      -> (T, d_model)
+ learned pos_embed[:T]             -> (T, d_model)
    |
    v
[GPTBlock] x num_layers             -> (T, d_model)
    |   each block:
    |     x = x + attn(ln1(x))              # causal self-attn
    |     x = x + ff2(relu(ff1(ln2(x))))    # FFN
    v
nnx.LayerNorm(d_model)              -> (T, d_model)
    |
    v
@ embed.embedding.value.T           -> (T, vocab_size) logits
Four attributes (embed, pos_embed, blocks, ln_f) plus the implicit
tied head, which is just a transpose of embed.embedding and adds
no separate parameters.
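The block recurrence in the diagram can be sketched in plain JAX, without nnx. This is a minimal illustration under stated assumptions, not the reference solution: it uses a single attention head, a LayerNorm with no learned scale or bias, and bias-free projections. The names layer_norm, causal_attn, pre_ln_block, and init_block are hypothetical helpers introduced only for this sketch.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-5):
    # Normalize each token (row) over the feature axis.
    # Assumption: no learned scale/bias, unlike nnx.LayerNorm.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def causal_attn(x, wq, wk, wv, wo):
    # Single-head causal self-attention on a (T, d_model) sequence.
    T, d = x.shape
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = (q @ k.T) / jnp.sqrt(d)
    # Lower-triangular mask: position t may attend to positions <= t only.
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v @ wo

def pre_ln_block(x, p):
    # Pre-LN ordering: normalize, apply the sublayer, add the residual.
    x = x + causal_attn(layer_norm(x), p["wq"], p["wk"], p["wv"], p["wo"])
    x = x + jax.nn.relu(layer_norm(x) @ p["w1"]) @ p["w2"]
    return x

def init_block(key, d_model, d_ff, scale=0.1):
    # Hypothetical initializer: small Gaussian weights for every matrix.
    shapes = {
        "wq": (d_model, d_model), "wk": (d_model, d_model),
        "wv": (d_model, d_model), "wo": (d_model, d_model),
        "w1": (d_model, d_ff), "w2": (d_ff, d_model),
    }
    keys = jax.random.split(key, len(shapes))
    return {name: scale * jax.random.normal(k, shape)
            for k, (name, shape) in zip(keys, shapes.items())}

params = init_block(jax.random.PRNGKey(0), d_model=8, d_ff=16)
x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
y = pre_ln_block(x, params)

# Causality check: changing the last token must not change earlier outputs.
y_perturbed = pre_ln_block(x.at[-1].set(0.0), params)
```

Comparing y[:-1] with y_perturbed[:-1] is a quick way to confirm the mask is applied on the correct side.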
Embedding two ways: token + position
nnx.Embed(vocab_size, d_model, rngs=rngs) is a built-in lookup
table — call it on int token ids to get the corresponding rows.
Position embeddings are simpler: a single learned matrix (max_T, d_model)
initialized to zeros (or normal-noise; zeros is fine for tiny tests).
Slice pos_embed.value[:T] and add to the token embeddings. Storing
it as nnx.Param(jnp.zeros((max_T, d_model))) makes it trainable.
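As a rough plain-JAX picture of what the two embeddings compute (an illustration of the arithmetic only, not nnx.Embed itself), the lookup is row indexing and the position term is a sliced add. The array names table and tok are placeholders chosen for this sketch.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model, max_T = 10, 4, 8

# Stand-in for the token embedding matrix: one row per vocabulary entry.
table = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))
# Learned position embeddings, zero-initialized as in the text.
pos_embed = jnp.zeros((max_T, d_model))

token_ids = jnp.array([3, 1, 4, 1], dtype=jnp.int32)
T = token_ids.shape[0]

tok = table[token_ids]      # row lookup -> (T, d_model)
x = tok + pos_embed[:T]     # add the first T position rows -> (T, d_model)
```

With a zero-initialized pos_embed, x equals tok at initialization, so repeated tokens (positions 1 and 3 here) start with identical rows until training moves the position rows apart.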
Stacking blocks: nnx.List
Plain Python lists in nnx 0.12.6 are treated as STATIC by default, which fails for a list of modules. Two equivalent fixes:
self.blocks = nnx.List([GPTBlock(...) for _ in range(num_layers)])
# or
self.blocks: list = nnx.data() # then assign a plain list
nnx.List is the cleanest option: it wraps the list as a data
pytree, so split/merge see all the parameters inside. Iterate with
a normal Python for block in self.blocks: loop.
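For intuition about why the container has to be visible as data, here is a plain-JAX analogy (it does not use nnx, and init_block is a hypothetical helper): a Python list of parameter dicts is itself a pytree, so tree utilities can reach every array in every block.

```python
import jax
import jax.numpy as jnp

def init_block(d_model, d_ff):
    # Hypothetical per-block parameters: just the two FFN matrices.
    return {"w1": jnp.zeros((d_model, d_ff)), "w2": jnp.zeros((d_ff, d_model))}

num_layers, d_model, d_ff = 3, 4, 8
blocks = [init_block(d_model, d_ff) for _ in range(num_layers)]

# tree_leaves walks the list and each dict, returning every array inside.
leaves = jax.tree_util.tree_leaves(blocks)
num_params = sum(leaf.size for leaf in leaves)
```

Here leaves holds two arrays per block. An attribute treated as static would be skipped by this kind of traversal, which is the situation nnx.List is meant to avoid.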
Tied output head
The classical GPT trick: the output projection’s weights ARE the
embedding matrix, transposed. So logits = x @ embed.embedding.value.T
has shape (T, vocab_size), and the only “head” parameter is shared
with the input embedding. This saves vocab_size * d_model parameters
and acts as a small but real regularizer.
Read the embedding matrix off the module:
embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
embed.embedding # the nnx.Param wrapper
embed.embedding.value # the underlying (vocab_size, d_model) array
Then logits = x @ embed.embedding.value.T. Note the .value
unwrap: @ (matmul) on the nnx.Param wrapper auto-unwraps for
arithmetic, but .T is an attribute lookup and must go through
.value first.
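The shape bookkeeping for the tied head can be checked with ordinary arrays. This sketch assumes E stands in for embed.embedding.value and that an untied head would have been a bias-free (d_model, vocab_size) kernel.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model, T = 10, 4, 3

E = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))  # embedding matrix
x = jax.random.normal(jax.random.PRNGKey(1), (T, d_model))           # final hidden states

# (T, d_model) @ (d_model, vocab_size) -> (T, vocab_size)
logits = x @ E.T

# Each logit is the dot product of a hidden state with one embedding row.
t, v = 2, 7
single = jnp.dot(x[t], E[v])

# Parameters an untied, bias-free head would have added.
saved_params = vocab_size * d_model
```

Reading logits[t, v] as a similarity between the hidden state at position t and the embedding of token v is one way to see why sharing the matrix is a sensible constraint.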
Casting token_ids
The harness passes token_ids as a float array (the only numeric
payload it knows). Cast inside the function: ids = token_ids.astype(jnp.int32).
Without this, nnx.Embed fails because lookup expects integer
indices.
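A small sketch of the cast, using a plain array as a stand-in for the embedding table (table and rows are illustrative names, and the float inputs mimic what the harness is described as passing):

```python
import jax.numpy as jnp

# Values arrive as floats, even though they represent integers.
token_ids = jnp.array([2.0, 0.0, 5.0])
d_model_raw = 2.0

ids = token_ids.astype(jnp.int32)   # integer indices for the lookup
d_model = int(d_model_raw)          # hyperparameters need Python ints for shapes

# Stand-in table with 6 rows: row i is [2*i, 2*i + 1].
table = jnp.arange(6 * d_model, dtype=jnp.float32).reshape(6, d_model)
rows = table[ids]                   # (3, d_model)
```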
Worked sketch
import jax
import jax.numpy as jnp
from flax import nnx

class GPTBlock(nnx.Module):
    # __init__ sets ln1, attn (causal MHA), ln2, ff1, ff2.
    def __call__(self, x):
        # Pre-LN: normalize, apply the sublayer, then add the residual.
        x = x + self.attn(self.ln1(x))
        x = x + self.ff2(jax.nn.relu(self.ff1(self.ln2(x))))
        return x

class MiniGPT(nnx.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_T, rngs):
        self.embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
        # Learned position table, zero-initialized and trainable.
        self.pos_embed = nnx.Param(jnp.zeros((max_T, d_model)))
        self.blocks = nnx.List([
            GPTBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.ln_f = nnx.LayerNorm(d_model, rngs=rngs)

    def __call__(self, token_ids):
        # token_ids must already be an int32 array of shape (T,).
        T = token_ids.shape[0]
        x = self.embed(token_ids) + self.pos_embed.value[:T]
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        # Tied head: reuse the embedding matrix, transposed.
        return x @ self.embed.embedding.value.T
Common pitfalls
- Plain self.blocks = [...]. Without nnx.List, nnx treats it as a static (non-data) attribute and the param tree is empty. Pass through nnx.List(...).
- Forgetting to cast token_ids to int. nnx.Embed needs integer indices; floats raise an error or silently produce zeros.
- Slicing position embeddings beyond max_T. pos_embed[:T] requires max_T >= T. Pass max_T = T for tests, or use a generous max length in production.
- Not transposing for the tied head. x @ embed.embedding.value is (T, d_model) @ (vocab_size, d_model), which is a shape mismatch. Need .T.
- Skipping the final LayerNorm. GPT has a ln_f BEFORE the output projection. Without it, the logits are unnormalized and training is harder (in inference, it changes the distribution).
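The max_T pitfall is easy to miss because an out-of-range static slice does not raise on its own; it returns fewer rows, and the failure only shows up later at the add. A minimal sketch, where check_seq_len is a hypothetical guard and not part of any library:

```python
import jax.numpy as jnp

max_T, d_model = 4, 2
pos_embed = jnp.zeros((max_T, d_model))

T = 6                      # sequence longer than the position table
sliced = pos_embed[:T]     # clamped to the available rows: shape (max_T, d_model)

def check_seq_len(T, max_T):
    # Hypothetical guard: fail early with a clear message.
    if T > max_T:
        raise ValueError(f"sequence length {T} exceeds max_T {max_T}")

try:
    check_seq_len(T, max_T)
    guard_raised = False
except ValueError:
    guard_raised = True
```

An explicit check like this turns a confusing broadcasting error into a direct message about the sequence length.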
Problem
Write mini_gpt_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
- Inner CausalMHA(nnx.Module) (causal self-attention, like pos 33).
- GPTBlock(nnx.Module) with two LayerNorms, the causal attn, and a two-Dense FFN with ReLU. Pre-LN sublayer ordering.
- MiniGPT(nnx.Module) with embed, pos_embed (zeros init, shape (max_T, d_model)), blocks (nnx.List of GPTBlock), ln_f. Forward: embed + pos, run blocks, final LN, tied head.
- Cast all hyperparameters to int. Cast token_ids to int32.
- Return logits.reshape(-1).
Inputs:
- seed: int (passed as float).
- token_ids: 1-D (T,) of token indices (passed as floats).
- vocab_size, d_model, num_heads, d_ff, num_layers: ints (passed as floats).
Output: 1-D flattened logits (T * vocab_size,).
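For reference, reshape(-1) flattens in row-major order, so the logit for position t and token v lands at index t * vocab_size + v. A tiny sketch with made-up values:

```python
import jax.numpy as jnp

T, vocab_size = 2, 3
# Placeholder logits 0..5 laid out as (T, vocab_size).
logits = jnp.arange(T * vocab_size, dtype=jnp.float32).reshape(T, vocab_size)

flat = logits.reshape(-1)           # shape (T * vocab_size,)

t, v = 1, 2
flat_index = t * vocab_size + v     # position of logits[t, v] in the flat array
```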