NNX Transformer Encoder Block
Why this matters
The Transformer encoder block is the canonical “Lego brick” of every
encoder-style architecture: BERT, ViT, T5’s encoder, the encoder half
of the original “Attention Is All You Need.” Stack N of these and
you have an encoder. Once you can write one block from scratch, the
whole zoo of encoder-only models is composition.
A block is two sublayers — multi-head self-attention and a position-wise feed-forward network — each wrapped in a LayerNorm and a residual connection. That’s it. Everything else (heads, FFN width, activation choice) is hyperparameters around the same skeleton.
Pre-LN vs post-LN
The original 2017 paper used post-LN (x = LayerNorm(x + Sublayer(x))).
Modern implementations almost universally use pre-LN
(x = x + Sublayer(LayerNorm(x))) because it’s much easier to train
deep stacks: gradients flow through the residual without going
through the LayerNorm first. We use pre-LN here.
x -> LN1 -> MHA -> + (residual) -> LN2 -> FFN -> + (residual) -> out
     \________________________/    \________________________/
             Sublayer 1                    Sublayer 2
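To make the difference between the two orderings concrete, here is a minimal sketch in plain JAX. It assumes a parameter-free layer_norm helper (no learned scale or bias) and a hypothetical zero_sublayer stand-in; neither comes from the original page. With a sublayer that contributes nothing, pre-LN returns x untouched, while post-LN still rescales it, which illustrates why the pre-LN residual path is an unobstructed route for gradients.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the last axis (assumption: no learned scale/bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def post_ln(x, sublayer):
    # Original 2017 ordering: normalize after the residual add.
    return layer_norm(x + sublayer(x))


def pre_ln(x, sublayer):
    # Pre-LN ordering: normalize only the sublayer's input.
    return x + sublayer(layer_norm(x))


def zero_sublayer(h):
    # Hypothetical stand-in sublayer that contributes nothing.
    return jnp.zeros_like(h)


x = jnp.array([[1.0, 2.0, 4.0], [0.0, 10.0, 20.0]])
pre_out = pre_ln(x, zero_sublayer)    # x passes through unchanged
post_out = post_ln(x, zero_sublayer)  # x is still rescaled by the LayerNorm
```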
The pieces
Five submodules, all simple nnx.Module attributes:
- ln1: nnx.LayerNorm(d_model, rngs=rngs), the pre-attention norm.
- attn: a multi-head self-attention module (write your own MHA, like pos 32; four nnx.Linear(d_model, d_model) projections plus the reshape/SDPA dance).
- ln2: nnx.LayerNorm(d_model, rngs=rngs), the pre-FFN norm.
- ff1: nnx.Linear(d_model, d_ff, rngs=rngs), the first FFN linear, expanding to the wider hidden width.
- ff2: nnx.Linear(d_ff, d_model, rngs=rngs), the second FFN linear, projecting back.
The FFN is Dense -> ReLU -> Dense. (GELU is more common in practice;
ReLU is fine here for predictability and speed.)
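The "reshape/SDPA dance" mentioned for attn can be sketched in plain JAX as follows. This is an illustration under assumptions, not the MHA from pos 32: the helper names split_heads, merge_heads, and multi_head_sdpa are hypothetical, inputs are a single unbatched (T, d_model) sequence, and q, k, v are taken to be already projected. There is no masking or dropout.

```python
import jax
import jax.numpy as jnp


def split_heads(t, num_heads):
    # (T, d_model) -> (num_heads, T, head_dim)
    T, d_model = t.shape
    return t.reshape(T, num_heads, d_model // num_heads).transpose(1, 0, 2)


def merge_heads(t):
    # (num_heads, T, head_dim) -> (T, d_model)
    num_heads, T, head_dim = t.shape
    return t.transpose(1, 0, 2).reshape(T, num_heads * head_dim)


def multi_head_sdpa(q, k, v, num_heads):
    # q, k, v: already-projected (T, d_model) arrays.
    qh, kh, vh = (split_heads(t, num_heads) for t in (q, k, v))
    head_dim = qh.shape[-1]
    scores = qh @ kh.transpose(0, 2, 1) / jnp.sqrt(head_dim)  # (H, T, T)
    weights = jax.nn.softmax(scores, axis=-1)                 # each row sums to 1
    return merge_heads(weights @ vh)


key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (5, 8))
k = jax.random.normal(key_k, (5, 8))
v = jnp.ones((5, 8))
out = multi_head_sdpa(q, k, v, num_heads=2)
```

Because every row of attention weights sums to 1, feeding an all-ones v returns all ones, which makes a convenient sanity check for the reshapes.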
Worked sketch
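import jax
from flax import nnx

class EncoderBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)  # pre-attention norm
        # MHA is your own multi-head self-attention module (see "The pieces").
        self.attn = MHA(d_model=d_model, num_heads=num_heads, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)  # pre-FFN norm
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)  # expand to d_ff
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)  # project back to d_model

    def __call__(self, x):
        # Sublayer 1: pre-LN self-attention plus residual.
        x = x + self.attn(self.ln1(x))
        # Sublayer 2: pre-LN feed-forward (Dense -> ReLU -> Dense) plus residual.
        x = x + self.ff2(jax.nn.relu(self.ff1(self.ln2(x))))
        return x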
Compare with Linen, where you’d pass train=... and deterministic=...
through every submodule and read params from a params collection.
In nnx the block is just a class that owns five attributes; calling
it is model(x).
Why two LayerNorms?
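Each sublayer (attention, FFN) gets its own LayerNorm because each
sees a differently distributed input: ln1 normalizes the stream before
the attention residual is added, ln2 the stream after it. Each LN also
carries its own learned scale and bias, so sharing one LN would force
the same affine rescaling on both sublayers' inputs and hurt
expressivity. This is why even the smallest blocks have ln1 and ln2.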
Why FFN width = 4 * d_model in production?
Mostly convention. The original paper used d_ff = 4 * d_model (e.g.,
d_model=512, d_ff=2048). A common rough rationale: the MHA sublayer
tends to be bandwidth-limited (softmax and reshapes around its
matmuls), while the FFN is two large matmuls and so FLOP-limited,
which makes a wider FFN a cheap place to add capacity. Here we just
pass d_ff as a hyperparameter so you can experiment.
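As a quick piece of arithmetic on what the 4x ratio means for parameter counts, the sketch below assumes every layer is a dense layer with a bias (as nnx.Linear is by default) and that attention uses four d_model-to-d_model projections. The helper linear_params is hypothetical and only does counting.

```python
def linear_params(n_in, n_out):
    # Weight matrix plus bias vector of one dense layer.
    return n_in * n_out + n_out


d_model, d_ff = 512, 2048

# Four d_model -> d_model projections (query, key, value, output).
attn_params = 4 * linear_params(d_model, d_model)

# Expand to d_ff, then project back to d_model.
ffn_params = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)

print(attn_params)  # 1050624
print(ffn_params)   # 2099712
```

At this ratio the FFN holds roughly twice as many parameters as the attention projections, so about two thirds of a block's weights sit in ff1 and ff2.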
Common pitfalls
- Forgetting the residual. Without it, the whole point of the block (gradient highway) disappears. Always x = x + Sublayer(...).
- Post-LN instead of pre-LN. Both work; we standardize on pre-LN. Make sure to apply LN to the input of each sublayer, not the output.
- Sharing one LN between sublayers. Two LNs, two normalizations.
- FFN with no nonlinearity. Two linear layers in a row with nothing between collapse to a single linear layer. The ReLU is load-bearing.
- d_model not divisible by num_heads. MHA asserts on this.
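The "FFN with no nonlinearity" pitfall can be checked numerically. The sketch below uses plain JAX arrays in place of nnx.Linear layers; the weights are random stand-ins and the sizes are chosen arbitrarily for illustration.

```python
import jax
import jax.numpy as jnp

d_model, d_ff = 4, 16
k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(0), 5)
W1, b1 = jax.random.normal(k1, (d_model, d_ff)), jax.random.normal(k2, (d_ff,))
W2, b2 = jax.random.normal(k3, (d_ff, d_model)), jax.random.normal(k4, (d_model,))
x = jax.random.normal(kx, (3, d_model))

# Two linear layers with nothing in between...
stacked = (x @ W1 + b1) @ W2 + b2

# ...match one linear layer with a merged weight and bias.
W, b = W1 @ W2, b1 @ W2 + b2
collapsed = x @ W + b

# With the ReLU in place, the merged layer no longer reproduces the output.
with_relu = jax.nn.relu(x @ W1 + b1) @ W2 + b2
```

Here stacked and collapsed agree up to floating-point error, while with_relu differs: the extra d_ff width only buys capacity when a nonlinearity sits between the two layers.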
Problem
Write encoder_block_forward(seed, x, num_heads, d_model, d_ff):
- Define an inner MHA(nnx.Module) (same as pos 32) and an EncoderBlock(nnx.Module) with ln1, attn, ln2, ff1, ff2 attributes.
- __call__(x) does x + attn(ln1(x)) then x + ff2(relu(ff1(ln2(x)))).
- Cast num_heads, d_model, d_ff from float to int. Build with nnx.Rngs(int(seed)).
- Return the output flattened: out.reshape(-1).
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- num_heads, d_model, d_ff: ints (passed as floats).
Output: 1-D flattened (T * d_model,).
Output: 1-D flattened (T * d_model,).