NNX Mini-BERT
Why this matters
BERT is GPT’s encoder-only sibling: same general layout (token embed + position embed + N Transformer blocks + final LayerNorm), but two crucial differences:
- No causal mask. Each token can attend to every other token, both directions. This is what makes BERT a bidirectional encoder — useful for fill-in-the-blank tasks, classification, and extracting representations of an entire input.
- No output head. BERT returns its final-layer hidden states, not vocab logits. Downstream tasks add their own task-specific head (CLS-token classifier, span-prediction MLP, etc.).
Once you have the GPT block from pos 43, switching to BERT is two edits: drop the causal mask in MHA, return hidden states instead of logits.
What “bidirectional” means
In a causal/autoregressive model, position i attends only to
positions 0..i. In a bidirectional encoder, position i attends
to ALL positions 0..T-1. The attention pattern is dense
(full T x T matrix), not lower-triangular. This is the same as
the encoder block from pos 41 — no mask at all.
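
To make the two attention patterns concrete, here is a minimal sketch in plain JAX. The all-zero score matrix is an assumption chosen for illustration, so that every allowed position receives equal weight after the softmax.

```python
import jax
import jax.numpy as jnp

T = 4
# Hypothetical raw attention scores for a 4-token sequence (all zeros,
# so the softmax spreads weight evenly over whatever is visible).
scores = jnp.zeros((T, T))

# Causal: position i may only see positions 0..i (lower triangle).
causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))
causal_weights = jax.nn.softmax(jnp.where(causal_mask, scores, -jnp.inf), axis=-1)

# Bidirectional: no mask, so every row covers all T positions.
bidir_weights = jax.nn.softmax(scores, axis=-1)

print(causal_weights)  # lower-triangular rows
print(bidir_weights)   # dense T x T matrix
```

In the causal case row 0 puts all of its weight on position 0, while in the bidirectional case every row is spread over the full sequence.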
Why it works for BERT: the training objective is masked language
modeling. Some token positions are replaced with a [MASK] token,
and the model predicts the original from BIDIRECTIONAL context.
Causal masking would prevent the model from using right-side
context, defeating the entire point.
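
As a rough illustration of the masked-LM input, the sketch below replaces chosen positions with a mask id and keeps the originals as targets. The helper name `mask_tokens`, the value of `MASK_ID`, and the hand-picked positions are all assumptions for the example; real pipelines typically choose positions at random.

```python
import jax.numpy as jnp

MASK_ID = 103  # hypothetical id for the [MASK] token; depends on the vocabulary

def mask_tokens(token_ids, mask_positions, mask_id=MASK_ID):
    """Replace the chosen positions with mask_id; return (corrupted, targets)."""
    token_ids = jnp.asarray(token_ids, dtype=jnp.int32)
    positions = jnp.asarray(mask_positions, dtype=jnp.int32)
    targets = token_ids[positions]                    # originals the model must recover
    corrupted = token_ids.at[positions].set(mask_id)  # functional update, no mutation
    return corrupted, targets

corrupted, targets = mask_tokens([7, 12, 5, 9, 30], mask_positions=[1, 3])
# corrupted -> [7, 103, 5, 103, 30]; targets -> [12, 9]
```

To predict the token at position 1, the model can use positions 0, 2, 3 and 4, which is only possible when attention is unmasked.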
Why it would be a disaster for GPT: GPT trains by next-token prediction over the whole sequence in parallel. Without the causal mask, the model could simply copy the next token directly from K/V, since it is sitting right there in the sequence. Causality is a hard architectural constraint, but only for autoregressive training.
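
The leak can be checked numerically. The sketch below uses a deliberately simplified single-head attention with identity projections (an assumption, not the lesson's MHA) and measures how much position 0's output moves when only the last token is changed.

```python
import jax
import jax.numpy as jnp

def self_attention(x, causal):
    """Single-head self-attention with identity projections (illustration only)."""
    T, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)
    if causal:
        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ x

x = jax.random.normal(jax.random.PRNGKey(0), (5, 8))
x_edit = x.at[4].set(0.0)  # change only the LAST token

# Shift in position 0's output caused by editing the last token.
causal_shift = jnp.abs(self_attention(x, True)[0] - self_attention(x_edit, True)[0]).max()
bidir_shift = jnp.abs(self_attention(x, False)[0] - self_attention(x_edit, False)[0]).max()
```

With the causal mask the shift is zero, because position 0 never sees position 4. Without it the shift is nonzero: information from later tokens reaches earlier positions, which is what BERT wants and what next-token training cannot tolerate.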
The architecture
token_ids (T,) int32
      |
      v
nnx.Embed(vocab_size, d_model)      -> (T, d_model)
+ learned pos_embed[:T]             -> (T, d_model)
      |
      v
[EncoderBlock] x num_layers         -> (T, d_model)
      |   each block (no mask):
      |     x = x + attn(ln1(x))                 # bidirectional self-attn
      |     x = x + ff2(relu(ff1(ln2(x))))       # FFN
      v
nnx.LayerNorm(d_model)              -> (T, d_model)   final hidden
Output is (T, d_model) — the per-token contextualized
representations. Flatten to (T * d_model,) for the harness.
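
The data flow in the diagram can be traced in a framework-free sketch written in plain JAX. This is not the NNX solution: the parameter dictionary, the helper names (`layer_norm`, `split_heads`, `unmasked_mha`, `init_params`, `mini_bert`), the missing biases, and the LayerNorm without learned scale are all simplifying assumptions made to keep the shapes easy to follow.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Assumption: no learned scale/bias, to keep the sketch short.
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def split_heads(x, num_heads):
    T, d_model = x.shape
    return x.reshape(T, num_heads, d_model // num_heads).transpose(1, 0, 2)

def unmasked_mha(x, p, num_heads):
    T, d_model = x.shape
    q = split_heads(x @ p["wq"], num_heads)                    # (H, T, hd)
    k = split_heads(x @ p["wk"], num_heads)
    v = split_heads(x @ p["wv"], num_heads)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(q.shape[-1])  # (H, T, T), no mask
    out = jax.nn.softmax(scores, axis=-1) @ v                  # (H, T, hd)
    return out.transpose(1, 0, 2).reshape(T, d_model) @ p["wo"]

def encoder_block(x, p, num_heads):
    x = x + unmasked_mha(layer_norm(x), p, num_heads)          # pre-LN attention
    x = x + jax.nn.relu(layer_norm(x) @ p["w1"]) @ p["w2"]     # pre-LN FFN
    return x

def init_params(key, vocab_size, d_model, d_ff, num_layers, max_T):
    def w(k, shape):
        return 0.02 * jax.random.normal(k, shape)
    k_embed, *block_keys = jax.random.split(key, num_layers + 1)
    blocks = []
    for bk in block_keys:
        ks = jax.random.split(bk, 6)
        blocks.append({
            "wq": w(ks[0], (d_model, d_model)), "wk": w(ks[1], (d_model, d_model)),
            "wv": w(ks[2], (d_model, d_model)), "wo": w(ks[3], (d_model, d_model)),
            "w1": w(ks[4], (d_model, d_ff)), "w2": w(ks[5], (d_ff, d_model)),
        })
    return {"embed": w(k_embed, (vocab_size, d_model)),
            "pos": jnp.zeros((max_T, d_model)), "blocks": blocks}

def mini_bert(params, token_ids, num_heads):
    T = token_ids.shape[0]
    x = params["embed"][token_ids] + params["pos"][:T]         # (T, d_model)
    for p in params["blocks"]:
        x = encoder_block(x, p, num_heads)
    return layer_norm(x)                                       # final hidden states

params = init_params(jax.random.PRNGKey(0), vocab_size=50, d_model=16,
                     d_ff=32, num_layers=2, max_T=8)
hidden = mini_bert(params, jnp.array([3, 14, 15, 9, 2], dtype=jnp.int32), num_heads=4)
flat = hidden.reshape(-1)  # (T * d_model,)
```

Here T = 5 and d_model = 16, so `hidden` has shape (5, 16) and `flat` has shape (80,).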
What the user adds on top
Any of these task heads, depending on use case:
- Classification (e.g., sentiment): take the hidden state at position 0 (the [CLS] token) and pass it through nnx.Linear(d_model, num_classes).
- Token tagging (e.g., NER): apply nnx.Linear(d_model, num_tags) to every position.
- Masked LM: a tied head as in GPT (hidden @ embed.embedding.T) computes vocab logits at every position.
Mini-BERT itself outputs the contextualized representations; the head is a separate module wrapped around it.
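
The three heads differ mainly in which positions they read and what shape they produce. The sketch below uses random arrays as stand-ins for the Mini-BERT output and for the weights; `w_cls` and `w_tag` play the role of bias-free nnx.Linear layers, and all sizes are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

T, d_model, vocab_size, num_classes, num_tags = 6, 16, 50, 3, 5
k_h, k_e, k_c, k_t = jax.random.split(jax.random.PRNGKey(0), 4)

hidden = jax.random.normal(k_h, (T, d_model))               # stand-in for Mini-BERT output
embedding = jax.random.normal(k_e, (vocab_size, d_model))   # stand-in for the token table
w_cls = jax.random.normal(k_c, (d_model, num_classes))
w_tag = jax.random.normal(k_t, (d_model, num_tags))

class_logits = hidden[0] @ w_cls    # (num_classes,)   reads position 0, the [CLS] slot
tag_logits = hidden @ w_tag         # (T, num_tags)    one prediction per token
mlm_logits = hidden @ embedding.T   # (T, vocab_size)  tied to the input embedding
```

None of these heads changes the encoder: each one consumes the same (T, d_model) hidden states.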
Worked sketch
import jax
import jax.numpy as jnp
from flax import nnx

class EncoderBlock(nnx.Module):
    # Identical to pos 41: pre-LN, two sublayers, no mask.
    # (__init__ omitted here; it builds ln1, attn, ln2, ff1, ff2.)
    def __call__(self, x):
        x = x + self.attn(self.ln1(x))                        # bidirectional self-attn
        x = x + self.ff2(jax.nn.relu(self.ff1(self.ln2(x))))  # FFN
        return x

class MiniBERT(nnx.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_T, rngs):
        self.embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
        self.pos_embed = nnx.Param(jnp.zeros((max_T, d_model)))
        self.blocks = nnx.List([
            EncoderBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.ln_f = nnx.LayerNorm(d_model, rngs=rngs)

    def __call__(self, token_ids):
        T = token_ids.shape[0]
        x = self.embed(token_ids) + self.pos_embed.value[:T]
        for block in self.blocks:
            x = block(x)
        return self.ln_f(x)  # hidden states (T, d_model), not logits
Compare with MiniGPT from pos 43: only the inner attention’s mask
line is removed, and there’s no @ embed.embedding.value.T at the
end.
Common pitfalls
- Leaving the causal mask in. Then you’ve built a “decoder-only with no head”, not BERT. EncoderBlock uses unmasked MHA.
- Returning logits instead of hidden states. BERT’s __call__ stops at the LayerNorm. Don’t multiply by the embedding matrix.
- Forgetting nnx.List. Same trap as MiniGPT: plain Python lists are static; submodule lists must be wrapped in nnx.List.
- Casting token_ids. Same as before: astype(jnp.int32) before the embedding lookup.
Problem
Write mini_bert_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
- Inner unmasked MHA(nnx.Module) (same as pos 32, no mask line).
- EncoderBlock(nnx.Module) with two LayerNorms, the unmasked attn, and a two-Dense FFN with ReLU. Pre-LN sublayer ordering.
- MiniBERT(nnx.Module) with embed, pos_embed, blocks (nnx.List), ln_f. Forward: embed + pos, run blocks, final LN. Return hidden states (T, d_model).
- Cast hyperparameters to int. Cast token_ids to int32.
- Return hidden.reshape(-1).
Inputs:
- seed: int (passed as float).
- token_ids: 1-D (T,) of indices (passed as floats).
- vocab_size, d_model, num_heads, d_ff, num_layers: ints (passed as floats).
Output: 1-D flattened hidden states (T * d_model,).
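
Because every input arrives as a float, the casts have to happen before any shape or index is built. The sketch below shows one way to do that; `normalize_inputs` is a hypothetical helper (not part of the required signature), and the `hidden` array is a stand-in used only to show how reshape(-1) orders the output.

```python
import jax.numpy as jnp

def normalize_inputs(seed, token_ids, *hyperparams):
    """Hypothetical helper: undo the harness's float encoding."""
    seed = int(seed)
    ints = tuple(int(h) for h in hyperparams)              # usable as shapes and counts
    token_ids = jnp.asarray(token_ids).astype(jnp.int32)   # usable as embedding indices
    return seed, token_ids, ints

seed, token_ids, (vocab_size, d_model) = normalize_inputs(0.0, [3.0, 14.0, 2.0], 50.0, 16.0)

# Stand-in for the (T, d_model) hidden states; reshape(-1) flattens row by row,
# so all of token 0's features come first, then token 1's, and so on.
hidden = jnp.arange(token_ids.shape[0] * d_model, dtype=jnp.float32).reshape(-1, d_model)
flat = hidden.reshape(-1)
```

With T = 3 and d_model = 16, `flat` has 48 entries and `flat[16]` is the first feature of token 1.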