
Mini BERT: Encoder-Only Hidden States

Why this matters

BERT (Devlin et al., 2018) flipped the script. Instead of predicting the next token, as GPT does, it predicts MASKED tokens using bidirectional attention. That requires attention with no causal mask, which is what an encoder-only Transformer provides: each token attends to every other token in both directions.

BERT became the workhorse for sentence embeddings, classification, NER, and QA: anywhere you want a representation of a sequence rather than a generated continuation.

Architecture

ids (T,)
→ Embed(V, D)        → tok (T, D)
→ + pos_embed        → x   (T, D)
→ encoder block × N  → x   (T, D)        # bidir self-attn + FFN
→ x (final hidden states)

Compared to GPT (pos 42):

  • NO causal mask: attention is bidirectional, so every position sees every other position.
  • NO output head: the model returns the hidden states directly. Real BERT adds task-specific heads on TOP of the encoder (Dense(2) for binary classification, Dense(V) for masked-LM, etc.); the encoder itself outputs (T, D).
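The first difference above is easiest to see in the attention weights themselves. The sketch below is a minimal single-head version in plain JAX. It assumes identity Q/K projections and random hidden states, and `attention_weights` is a hypothetical helper, not a Flax API. It compares the (T, T) weight matrix with and without a causal mask.

```python
import jax
import jax.numpy as jnp

def attention_weights(q, k, causal):
    """Single-head softmax(q k^T / sqrt(d)) weights, optionally causally masked."""
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)                     # (T, T)
    if causal:
        T = scores.shape[0]
        keep = jnp.tril(jnp.ones((T, T), dtype=bool))  # position t sees only <= t
        scores = jnp.where(keep, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))   # (T, D) hidden states
w_bidir = attention_weights(x, x, causal=False)
w_causal = attention_weights(x, x, causal=True)

# Entries above the diagonal are attention paid to future positions.
upper = jnp.triu(jnp.ones((4, 4), dtype=bool), k=1)
print(bool(jnp.all(w_bidir[upper] > 0)))    # True: every position sees the future
print(bool(jnp.all(w_causal[upper] == 0)))  # True: the causal mask hides the future
```

In the bidirectional case every row is a full softmax over all T positions. The causal mask zeroes out the upper triangle before normalization.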

The original BERT also has a token-type (segment) embedding for sentence-pair tasks. We omit it here for clarity; the core architecture is the encoder stack.

Why bidirectional

Causal attention restricts position t to positions ≤ t. That is correct for left-to-right generation. But for understanding tasks ("what is this sentence about?"), every position should see EVERY other position, past and future. Drop the mask and train on a masked-LM objective, and you get a model whose representations encode bidirectional context.
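One way to check the "past and future" claim numerically is to differentiate position 0's output with respect to the input at the last position. The sketch below makes several assumptions: single-head attention, identity projections, and random inputs. The helpers `self_attention` and `first_pos_output` are hypothetical. With the causal mask the Jacobian is exactly zero; without it, position 0 depends on the future token.

```python
import jax
import jax.numpy as jnp

def self_attention(x, causal):
    # Single-head attention with identity projections, for illustration only.
    T, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)
    if causal:
        scores = jnp.where(jnp.tril(jnp.ones((T, T), dtype=bool)), scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ x

x = jax.random.normal(jax.random.PRNGKey(1), (4, 8))  # (T, D)

def first_pos_output(last_token, causal):
    # Output at position 0 as a function of the LAST position's input vector.
    return self_attention(x.at[3].set(last_token), causal)[0]

J_bidir = jax.jacobian(lambda v: first_pos_output(v, causal=False))(x[3])
J_causal = jax.jacobian(lambda v: first_pos_output(v, causal=True))(x[3])
print(bool(jnp.abs(J_bidir).max() > 0))  # True: position 0 reads position 3
print(bool(jnp.all(J_causal == 0)))      # True: position 0 cannot see position 3
```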

Worked walk-through

With V=16, D=8, H=2, d_ff=16, N=2 layers, T=4:

  1. tok = embed(ids) → (4, 8).
  2. pos = pos_embed[:T] → (4, 8). x = tok + pos.
  3. Block 1: x = x + MHA(LN(x)) (no mask); x = x + FFN(LN(x)).
  4. Block 2: same.
  5. Return x → (4, 8). Flatten to (32,).
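Steps 3 to 5 can be written out directly in plain JAX to make the Pre-LN residual pattern and the head split concrete. This is a sketch under simplifying assumptions. LayerNorm has no learned scale or bias, the FFN has no biases, and weights are random with std 0.02. The names `layer_norm`, `mha`, `ffn`, `encoder_block`, and `init_block` are hypothetical helpers, not the problem's reference solution.

```python
import jax
import jax.numpy as jnp

D, H, D_FF = 8, 2, 16  # walk-through sizes; two blocks are applied below

def layer_norm(x, eps=1e-6):
    # Normalize over features; learned scale/bias omitted for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def mha(x, p):
    # Bidirectional multi-head self-attention: no mask is applied.
    T = x.shape[0]
    dh = D // H
    def split(y):  # (T, D) -> (H, T, dh)
        return y.reshape(T, H, dh).transpose(1, 0, 2)
    q, k, v = split(x @ p["wq"]), split(x @ p["wk"]), split(x @ p["wv"])
    w = jax.nn.softmax(q @ k.transpose(0, 2, 1) / jnp.sqrt(dh), axis=-1)  # (H, T, T)
    heads = (w @ v).transpose(1, 0, 2).reshape(T, D)                      # concat heads
    return heads @ p["wo"]

def ffn(x, p):
    return jax.nn.relu(x @ p["w1"]) @ p["w2"]  # biases omitted for brevity

def encoder_block(x, p):
    x = x + mha(layer_norm(x), p)  # x = x + MHA(LN(x))
    x = x + ffn(layer_norm(x), p)  # x = x + FFN(LN(x))
    return x

def init_block(key):
    shapes = {"wq": (D, D), "wk": (D, D), "wv": (D, D), "wo": (D, D),
              "w1": (D, D_FF), "w2": (D_FF, D)}
    keys = jax.random.split(key, len(shapes))
    return {name: 0.02 * jax.random.normal(k, shape)
            for k, (name, shape) in zip(keys, shapes.items())}

k_x, k_blocks = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (4, D))                 # stands in for tok + pos
for block_key in jax.random.split(k_blocks, 2):    # Block 1, Block 2
    x = encoder_block(x, init_block(block_key))
out = x.reshape(-1)
print(x.shape, out.shape)  # (4, 8) (32,)
```

Every sub-layer maps (T, D) to (T, D), which is what makes the residual adds valid. That is also why the final flatten always has length T * D.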

Common pitfalls

  • Adding a causal mask by accident: this defeats the whole point. Just call MHA(h); Flax's default (no mask) is bidirectional.
  • Adding a final output projection: BERT-base outputs hidden states. The classifier head is a separate downstream module; for this problem we return the encoder output.
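  • Skipping the position embedding: same caveat as GPT. Without it, self-attention has no notion of order, so identical tokens at different positions get identical representations.
  • Using ReLU vs GELU: the original BERT uses GELU, but for simplicity we use ReLU here. The block structure and parameter shapes are identical; only the nonlinearity differs, which changes the output values.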
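A quick check of the last pitfall, using JAX's built-in activations: same shapes, different values. Note that `jax.nn.gelu` defaults to the tanh approximation (`approximate=True`); passing `approximate=False` gives the exact erf form.

```python
import jax
import jax.numpy as jnp

z = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
relu = jax.nn.relu(z)
gelu = jax.nn.gelu(z)                           # tanh approximation (JAX default)
gelu_exact = jax.nn.gelu(z, approximate=False)  # exact erf form

print(relu)
print(gelu)
print(gelu_exact)
```

ReLU clamps negative inputs to exactly zero, while GELU lets small negative values through. Swapping one for the other therefore changes every downstream hidden state even though no parameter shape changes.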

Problem

Implement mini_bert_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):

  1. Cast configs to int. Cast token_ids to jnp.int32.
  2. Build MiniBERT:
    • embed = nn.Embed(V, D). tok = embed(ids).
    • pos = self.param("pos_embed", normal(0.02), (T, D)).
    • x = tok + pos.
    • Loop N encoder blocks (bidirectional self-attn + FFN, Pre-LN).
    • Return x.
  3. Init with a PRNG key built from seed, apply to the ids, and return x.reshape(-1).

Reuse the EncoderBlock recipe from pos 39 (no mask).

Inputs:

  • seed: int.
  • token_ids: 1-D float (cast inside).
  • vocab_size, d_model, num_heads, d_ff, num_layers: ints.

Output: 1-D, length T * D: the flattened hidden states.

Hints

flax transformer bert encoder
