
NNX Mini-BERT

Why this matters

BERT is GPT’s encoder-only sibling: same general layout (token embed + position embed + N Transformer blocks + final LayerNorm), but with two crucial differences:
  1. No causal mask. Each token can attend to every other token, both directions. This is what makes BERT a bidirectional encoder — useful for fill-in-the-blank tasks, classification, and extracting representations of an entire input.
  2. No output head. BERT returns its final-layer hidden states, not vocab logits. Downstream tasks add their own task-specific head (CLS-token classifier, span-prediction MLP, etc.).

Once you have the GPT block from pos 43, switching to BERT is two edits: drop the causal mask in MHA, return hidden states instead of logits.

What “bidirectional” means

In a causal/autoregressive model, position i attends only to positions 0..i. In a bidirectional encoder, position i attends to ALL positions 0..T-1. The attention pattern is dense (full T x T matrix), not lower-triangular. This is the same as the encoder block from pos 41 — no mask at all.
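A minimal plain-JAX sketch of the two patterns, assuming single-head attention on random toy queries and keys (the helper name attention_weights is illustrative, not from the course code). The causal weights are zero above the diagonal; the bidirectional weights fill the full T x T matrix.

```python
import jax
import jax.numpy as jnp


def attention_weights(q, k, causal):
    """Single-head attention weights of shape (T, T); causal=True blocks j > i."""
    T, d = q.shape
    scores = q @ k.T / jnp.sqrt(d)
    if causal:
        allowed = jnp.tril(jnp.ones((T, T), dtype=bool))  # lower-triangular mask
        scores = jnp.where(allowed, scores, -jnp.inf)     # exp(-inf) = 0 after softmax
    return jax.nn.softmax(scores, axis=-1)                # each row sums to 1


key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (4, 8))
k = jax.random.normal(key_k, (4, 8))

w_causal = attention_weights(q, k, causal=True)   # zeros above the diagonal
w_bidir = attention_weights(q, k, causal=False)   # dense: every entry is positive

print(jnp.round(w_causal, 2))
print(jnp.round(w_bidir, 2))
```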

Why it works for BERT: the training objective is masked language modeling. Some token positions are replaced with a [MASK] token, and the model predicts the original from BIDIRECTIONAL context. Causal masking would prevent the model from using right-side context, defeating the entire point.
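The masking step and its loss can be sketched roughly as follows. MASK_ID, the 15% default rate, and the helper names are assumptions for illustration. The original BERT recipe also sometimes keeps the chosen token or swaps in a random one instead of always inserting [MASK]; this sketch skips that. The loss is averaged only over the chosen positions, since those are the ones the model has to reconstruct from context on both sides.

```python
import jax
import jax.numpy as jnp

MASK_ID = 103  # hypothetical id of the [MASK] token in some vocabulary


def mask_tokens(key, token_ids, mask_rate=0.15):
    """Replace a random subset of positions with MASK_ID; return (inputs, chosen)."""
    chosen = jax.random.bernoulli(key, mask_rate, token_ids.shape)  # (T,) bool
    inputs = jnp.where(chosen, MASK_ID, token_ids)
    return inputs, chosen


def mlm_loss(logits, targets, chosen):
    """Mean cross-entropy over masked positions only.

    logits: (T, vocab), targets: (T,) original ids, chosen: (T,) bool.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]  # (T,)
    # Unmasked positions contribute nothing; guard against dividing by zero.
    return jnp.sum(nll * chosen) / jnp.maximum(chosen.sum(), 1)


token_ids = jnp.arange(10, 30, dtype=jnp.int32)        # toy sequence, T = 20
inputs, chosen = mask_tokens(jax.random.PRNGKey(0), token_ids)
logits = jnp.zeros((token_ids.shape[0], 200))          # uniform guesses over 200 ids
loss = mlm_loss(logits, token_ids, chosen)             # log(200) if anything was masked
```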

Why it would be a disaster for GPT: GPT trains by next-token prediction over the whole sequence in parallel, so the target for position i is the token at position i+1. Without the causal mask, the model can simply copy that token from K/V, because it is sitting right there in the input. The causal mask is a hard architectural constraint for autoregressive training only; in masked LM the target token is replaced by [MASK] in the input, so there is nothing to copy.
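One way to see the leak concretely, in a toy sketch that assumes identity projections (q = k = v = x) and a hypothetical self_attention helper: perturb only the row at position 3, which holds position 2's next-token target, and check whether position 2's output moves.

```python
import jax
import jax.numpy as jnp


def self_attention(x, causal):
    """Toy single-head self-attention with identity projections (q = k = v = x)."""
    T, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)
    if causal:
        allowed = jnp.tril(jnp.ones((T, T), dtype=bool))
        scores = jnp.where(allowed, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ x


x = jax.random.normal(jax.random.PRNGKey(0), (5, 4))
# Position 2's next-token target lives at position 3; perturb only that row.
x_perturbed = x.at[3].add(10.0)

deltas = {}
for causal in (True, False):
    out_2 = self_attention(x, causal)[2]
    out_2_perturbed = self_attention(x_perturbed, causal)[2]
    deltas[causal] = float(jnp.abs(out_2 - out_2_perturbed).max())

# With the mask, position 2 cannot see position 3, so its delta is (numerically) zero.
# Without it, position 2's output shifts: the "future" token has leaked in.
print(deltas)
```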

The architecture

token_ids (T,) int32
  |
  v
nnx.Embed(vocab_size, d_model)       -> (T, d_model)
  + learned pos_embed[:T]            -> (T, d_model)
  |
  v
[EncoderBlock] x num_layers          -> (T, d_model)
  |   each block (no mask):
  |     x = x + attn(ln1(x))           # bidirectional self-attn
  |     x = x + ff2(relu(ff1(ln2(x)))) # FFN
  v
nnx.LayerNorm(d_model)               -> (T, d_model)   final hidden

Output is (T, d_model) — the per-token contextualized representations. Flatten to (T * d_model,) for the harness.
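Because JAX arrays are row-major, the flattened vector is just the T per-token rows laid end to end. A quick check with a stand-in array (not real encoder output):

```python
import jax.numpy as jnp

T, d_model = 6, 16
# Stand-in for the (T, d_model) hidden states; values are just 0..95.
hidden = jnp.arange(T * d_model, dtype=jnp.float32).reshape(T, d_model)

flat = hidden.reshape(-1)              # (T * d_model,) = (96,), what the harness receives
restored = flat.reshape(T, d_model)    # undoes the flatten exactly

# Token t occupies the contiguous slice flat[t * d_model:(t + 1) * d_model].
t = 1
token_1 = flat[t * d_model:(t + 1) * d_model]
```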

What the user adds on top

Any of these task heads, depending on use case:

  • Classification (e.g., sentiment): take hidden state at position 0 (the [CLS] token), pass through nnx.Linear(d_model, num_classes).
  • Token tagging (e.g., NER): apply nnx.Linear(d_model, num_tags) to every position.
  • Masked LM: a tied head as in GPT (hidden @ embed.embedding.value.T) computes vocab logits at every position.

Mini-BERT itself outputs the contextualized representations; the head is a separate module wrapped around it.

Worked sketch

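import jax
import jax.numpy as jnp
from flax import nnx

class EncoderBlock(nnx.Module):
    # Same layout as pos 41: pre-LN, two residual sublayers, no mask.
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)
        # Stand-in for the hand-written unmasked MHA; decode=False means no KV cache.
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads, in_features=d_model, decode=False, rngs=rngs
        )
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)

    def __call__(self, x):
        x = x + self.attn(self.ln1(x))                        # no mask: fully bidirectional
        x = x + self.ff2(jax.nn.relu(self.ff1(self.ln2(x))))  # position-wise FFN
        return x

class MiniBERT(nnx.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_T, rngs):
        self.embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
        self.pos_embed = nnx.Param(jnp.zeros((max_T, d_model)))  # learned position table
        self.blocks = nnx.List([
            EncoderBlock(d_model, num_heads, d_ff, rngs=rngs)
            for _ in range(num_layers)
        ])
        self.ln_f = nnx.LayerNorm(d_model, rngs=rngs)

    def __call__(self, token_ids):
        token_ids = token_ids.astype(jnp.int32)  # nnx.Embed needs integer ids
        T = token_ids.shape[0]
        x = self.embed(token_ids) + self.pos_embed.value[:T]
        for block in self.blocks:
            x = block(x)
        return self.ln_f(x)  # (T, d_model) hidden states; no vocab projection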

Compare with MiniGPT from pos 43: only the inner attention’s mask line is removed, and there’s no @ embed.embedding.value.T at the end.

Common pitfalls

  • Leaving the causal mask in. Then you’ve built a “decoder-only with no head” — not BERT. EncoderBlock uses unmasked MHA.
  • Returning logits instead of hidden states. BERT’s __call__ stops at the LayerNorm. Don’t multiply by the embedding matrix.
  • Forgetting nnx.List. Same trap as MiniGPT — plain Python lists are static; submodule lists must be wrapped in nnx.List.
  • Forgetting to cast token_ids. The harness passes them as floats, so call astype(jnp.int32) before the embedding lookup.

Problem

Write mini_bert_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):

  1. Inner unmasked MHA(nnx.Module) (same as pos 32 — no mask line).
  2. EncoderBlock(nnx.Module) with two LayerNorms, the unmasked attn, and a two-layer FFN (two nnx.Linear layers) with ReLU. Pre-LN sublayer ordering.
  3. MiniBERT(nnx.Module) with embed, pos_embed, blocks (nnx.List), ln_f. Forward: embed + pos, run blocks, final LN. Return hidden states (T, d_model).
  4. Cast hyperparameters to int. Cast token_ids to int32.
  5. Return hidden.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • token_ids: 1-D (T,) of indices (passed as floats).
  • vocab_size, d_model, num_heads, d_ff, num_layers: ints (passed as floats).

Output: 1-D flattened hidden states (T * d_model,).
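Step 4 and the Inputs list amount to a small conversion layer in front of the model. A hypothetical parse_inputs helper might look like this; sizing max_T from the sequence length and checking d_model % num_heads == 0 are assumptions of this sketch, not requirements stated above.

```python
import jax.numpy as jnp


def parse_inputs(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
    """Hypothetical helper: turn the harness's float arguments into usable types."""
    ids = jnp.asarray(token_ids).astype(jnp.int32)  # (T,) int32 for nnx.Embed
    hparams = {
        "vocab_size": int(vocab_size),
        "d_model": int(d_model),
        "num_heads": int(num_heads),
        "d_ff": int(d_ff),
        "num_layers": int(num_layers),
        "max_T": int(ids.shape[0]),  # assumption: size pos_embed to this sequence
    }
    # Assumption: standard MHA splits d_model evenly across heads.
    if hparams["d_model"] % hparams["num_heads"] != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return int(seed), ids, hparams


seed, ids, hparams = parse_inputs(0.0, [3.0, 1.0, 4.0], 10.0, 8.0, 2.0, 16.0, 1.0)
```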

Hints

flax nnx transformer bert encoder-only architecture
