Mini BERT: Encoder-Only Hidden States
Why this matters
BERT (Devlin et al., 2018) flipped the script: instead of predicting the next token (as GPT does), it predicts MASKED tokens using bidirectional attention. That's only possible with an encoder-only Transformer, which has no causal mask: each token attends to every other token in both directions.
BERT became the workhorse for sentence embeddings, classification, NER, and QA: anywhere you want a representation of a sequence rather than a generated continuation.
Architecture
ids (T,)
  → Embed(V, D) → tok (T, D)
  → + pos_embed → x (T, D)
  → encoder block × N → x (T, D)   # bidir self-attn + FFN
  → x (final hidden states)
Compared to GPT (pos 42):
- NO causal mask: bidirectional attention. Every position sees every other position.
- NO output head: return the hidden states directly. Real BERT has task-specific heads (Dense(2) for binary classification, Dense(V) for masked-LM, etc.) added on TOP of the encoder; the encoder itself outputs (T, D).
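To make the split between encoder and head concrete, here is a NumPy sketch of two heads applied on top of a (T, D) hidden-state array. The random weights, the sizes, and the choice of position 0 as a [CLS]-style summary are assumptions for illustration; real BERT inserts extra layers (for example a tanh pooler before the classifier) that this sketch leaves out.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, V = 4, 8, 16
hidden = rng.normal(size=(T, D))  # stands in for the encoder's (T, D) output

# Masked-LM head (assumed plain linear layer): one V-way score vector per position.
W_mlm = rng.normal(scale=0.02, size=(D, V))
b_mlm = np.zeros(V)
mlm_logits = hidden @ W_mlm + b_mlm  # (T, V)

# Binary classification head: summarize the sequence, then project to 2 classes.
# Assumption: position 0 plays the role of BERT's [CLS] token.
W_cls = rng.normal(scale=0.02, size=(D, 2))
b_cls = np.zeros(2)
cls_logits = hidden[0] @ W_cls + b_cls  # (2,)

print(mlm_logits.shape, cls_logits.shape)  # (4, 16) (2,)
```

Either head can be swapped without touching the encoder, which is why the problem below stops at the hidden states.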
The original BERT also has a token-type (segment) embedding for sentence-pair tasks. We omit it here for clarity; the core architecture is the encoder stack.
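If you did want segment embeddings, they would be one more lookup table summed into the input. A NumPy sketch, assuming two segments and hypothetical table names (tok_table, pos_table, seg_table):

```python
import numpy as np

rng = np.random.default_rng(0)
V, D, T = 16, 8, 6
tok_table = rng.normal(scale=0.02, size=(V, D))  # token embeddings
pos_table = rng.normal(scale=0.02, size=(T, D))  # learned position embeddings
seg_table = rng.normal(scale=0.02, size=(2, D))  # hypothetical: segment A = 0, segment B = 1

ids = np.array([1, 5, 7, 2, 9, 3])
segment_ids = np.array([0, 0, 0, 1, 1, 1])  # first three tokens = sentence A, rest = sentence B

# All three embeddings are summed elementwise, so the encoder input stays (T, D).
x = tok_table[ids] + pos_table[:T] + seg_table[segment_ids]
print(x.shape)  # (6, 8)
```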
Why bidirectional
Causal attention forces position t to see only positions ≤ t. That's correct for left-to-right generation. But for understanding tasks ("what is this sentence about?"), every position should see EVERY other position, past and future. Drop the mask, train on a masked-LM objective, and you get a model whose representations encode bidirectional context.
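The difference is easy to see in the attention weights themselves. This NumPy sketch (single head, with queries and keys both set to the raw input, an assumption that keeps it short) builds the same weights with and without a causal mask:

```python
import numpy as np

def attention_weights(q, k, causal):
    """Single-head softmax(q k^T / sqrt(d)) weights, optionally with a causal mask."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)  # (T, T)
    if causal:
        # Position t may only attend to positions <= t.
        future = np.triu(np.ones((T, T), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
causal_w = attention_weights(x, x, causal=True)
bidir_w = attention_weights(x, x, causal=False)

print(np.allclose(np.triu(causal_w, k=1), 0))  # True: no weight on future positions
print((bidir_w > 0).all())                     # True: every position attends to every other
```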
Worked walk-through
With V=16, D=8, H=2, d_ff=16, L=2, T=4:
- tok = embed(ids) → (4, 8).
- pos = pos_embed[:T] → (4, 8). x = tok + pos.
- Block 1: x = x + MHA(LN(x)) (no mask); x = x + FFN(LN(x)).
- Block 2: same.
- Return x → (4, 8). Flatten to (32,).
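The same walk-through as a framework-free NumPy sketch. It assumes bias-free projections, random weights with standard deviation 0.02, and LayerNorm without a learned scale/shift, so its numbers will not match a Flax model; only the shapes and the order of operations are meant to carry over.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each position over the feature axis (learned scale/shift omitted).
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def mha(x, Wq, Wk, Wv, Wo, H):
    """Bidirectional multi-head self-attention: no mask, every position sees every other."""
    T, D = x.shape
    dh = D // H
    # Project, then split features into H heads: (H, T, dh).
    q = (x @ Wq).reshape(T, H, dh).transpose(1, 0, 2)
    k = (x @ Wk).reshape(T, H, dh).transpose(1, 0, 2)
    v = (x @ Wv).reshape(T, H, dh).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(dh)  # (H, T, T), no causal mask
    scores -= scores.max(-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(-1, keepdims=True)
    out = (w @ v).transpose(1, 0, 2).reshape(T, D)   # merge heads back to (T, D)
    return out @ Wo

def mini_bert_numpy(ids, V=16, D=8, H=2, d_ff=16, L=2, seed=0):
    rng = np.random.default_rng(seed)
    init = lambda *shape: rng.normal(scale=0.02, size=shape)
    T = len(ids)
    x = init(V, D)[ids] + init(T, D)  # tok + pos -> (T, D)
    for _ in range(L):
        Wq, Wk, Wv, Wo = (init(D, D) for _ in range(4))
        x = x + mha(layer_norm(x), Wq, Wk, Wv, Wo, H)     # Pre-LN attention sublayer
        W1, W2 = init(D, d_ff), init(d_ff, D)
        x = x + np.maximum(layer_norm(x) @ W1, 0.0) @ W2  # Pre-LN ReLU FFN sublayer
    return x  # (T, D) hidden states, no output head

hidden = mini_bert_numpy(np.array([3, 1, 4, 1]))
print(hidden.shape, hidden.reshape(-1).shape)  # (4, 8) (32,)
```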
Common pitfalls
- Adding a causal mask by accident: defeats the whole point. Just call MHA(h); Flax's default is no mask, i.e. bidirectional attention.
- Adding a final output projection: BERT-base outputs hidden states. The classifier head is a separate downstream module; for this problem we return the encoder output.
- Skipping the position embedding: same caveat as GPT. Without it, unmasked self-attention cannot tell token order apart, so the same token at different positions is indistinguishable.
- Using ReLU vs GeLU: original BERT uses GeLU, but for simplicity we use ReLU here. The block structure is identical; only the nonlinearity differs, so outputs will not match a GeLU model numerically.
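The position-embedding pitfall can be checked directly: without positions, unmasked self-attention is permutation-equivariant, so reordering the input tokens just reorders the output rows. This NumPy sketch uses a single head with identity Q/K/V projections (an assumption to keep it minimal) and random tables:

```python
import numpy as np

def self_attention(x):
    """Single-head, unmasked self-attention with identity projections, plus a residual."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(-1, keepdims=True)
    return x + w @ x

rng = np.random.default_rng(0)
V, D = 16, 8
tok_table = rng.normal(size=(V, D))
pos_table = rng.normal(size=(4, D))

ids = np.array([3, 1, 4, 2])
perm = np.array([2, 0, 3, 1])  # reorder the same tokens

# Without position embeddings: the output rows are permuted the same way.
out = self_attention(tok_table[ids])
out_perm = self_attention(tok_table[ids[perm]])
print(np.allclose(out[perm], out_perm))  # True: order carries no information

# With position embeddings added, reordering changes the representations.
out = self_attention(tok_table[ids] + pos_table)
out_perm = self_attention(tok_table[ids[perm]] + pos_table)
print(np.allclose(out[perm], out_perm))  # False
```

Because the rows only move around in the first case, nothing downstream can recover word order unless position information is added to the input.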
Problem
Implement mini_bert_forward(seed, token_ids, vocab_size, d_model, num_heads, d_ff, num_layers):
- Cast configs to int. Cast token_ids to jnp.int32.
- Build MiniBERT:
  - embed = nn.Embed(V, D). tok = embed(ids).
  - pos = self.param("pos_embed", normal(0.02), (T, D)).
  - x = tok + pos.
  - Loop N encoder blocks (bidirectional self-attn + FFN, Pre-LN).
  - Return x.
- Init/apply, return x.reshape(-1).
Reuse the EncoderBlock recipe from pos 39 (no mask).
Inputs:
- seed: int.
- token_ids: 1-D float (cast inside).
- vocab_size, d_model, num_heads, d_ff, num_layers: ints.
Output: 1-D, length T * D (the flattened hidden states).