
NNX Transformer Encoder Block

Why this matters

The Transformer encoder block is the canonical “Lego brick” of every encoder-style architecture: BERT, ViT, T5’s encoder, the encoder half of the original “Attention Is All You Need.” Stack N of these and you have an encoder. Once you can write one block from scratch, the whole zoo of encoder-only models is composition.

A block is two sublayers — multi-head self-attention and a position-wise feed-forward network — each wrapped in a LayerNorm and a residual connection. That’s it. Everything else (heads, FFN width, activation choice) is hyperparameters around the same skeleton.

Pre-LN vs post-LN

The original 2017 paper used post-LN: x = LayerNorm(x + Sublayer(x)). Most modern implementations use pre-LN, x = x + Sublayer(LayerNorm(x)), because deep stacks are much easier to train that way: the residual path carries x from input to output unchanged, so gradients reach earlier layers without passing through a LayerNorm. We use pre-LN here.

x -> LN1 -> MHA -> + (residual) -> LN2 -> FFN -> + (residual) -> out
    \__________________/         \__________________/
          Sublayer 1                    Sublayer 2
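
To make the difference concrete, here is a minimal sketch of both orderings in plain JAX. The layer_norm helper (no learned scale or bias) and the toy linear sublayer are assumptions made for illustration; they stand in for nnx.LayerNorm and for MHA or the FFN.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Normalize over the feature axis; learned scale/bias omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def post_ln(x, sublayer):
    # 2017-style ordering: normalize after adding the residual.
    return layer_norm(x + sublayer(x))

def pre_ln(x, sublayer):
    # Pre-LN ordering: normalize only the sublayer input; the skip path is untouched.
    return x + sublayer(layer_norm(x))

# Hypothetical toy sublayer: a fixed linear map standing in for MHA or the FFN.
w = jax.random.normal(jax.random.PRNGKey(0), (8, 8)) * 0.1
toy_sublayer = lambda h: h @ w

x = jax.random.normal(jax.random.PRNGKey(1), (4, 8)) * 5.0  # (T, d_model)
y_post = post_ln(x, toy_sublayer)
y_pre = pre_ln(x, toy_sublayer)

# With a sublayer that outputs zeros, pre-LN returns x unchanged,
# while post-LN still rescales it.
pre_is_identity = bool(jnp.allclose(pre_ln(x, jnp.zeros_like), x))
post_is_identity = bool(jnp.allclose(post_ln(x, jnp.zeros_like), x))
```

The zero-sublayer check is the identity-path argument in miniature: in pre-LN nothing sits between the block input and its output except the additions.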

The pieces

Five submodules, all simple nnx.Module attributes:

  • ln1: nnx.LayerNorm(d_model, rngs=rngs) — pre-attention norm.
  • attn: a multi-head self-attention module (write your own MHA, like pos 32; four nnx.Linear(d_model, d_model) projections plus the reshape/SDPA dance).
  • ln2: nnx.LayerNorm(d_model, rngs=rngs) — pre-FFN norm.
  • ff1: nnx.Linear(d_model, d_ff, rngs=rngs) — first FFN linear, expanding to the wider hidden width.
  • ff2: nnx.Linear(d_ff, d_model, rngs=rngs) — second FFN linear, projecting back.

The FFN is Linear -> ReLU -> Linear: ff1, then jax.nn.relu, then ff2. (GELU is more common in practice; ReLU is fine here for predictability and speed.)

Worked sketch

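import jax
from flax import nnx

class EncoderBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)   # pre-attention norm
        # MHA is your own multi-head self-attention module, defined separately.
        self.attn = MHA(d_model=d_model, num_heads=num_heads, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)   # pre-FFN norm
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)  # expand to d_ff
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)  # project back to d_model

    def __call__(self, x):
        # Sublayer 1: pre-LN self-attention plus residual.
        x = x + self.attn(self.ln1(x))
        # Sublayer 2: pre-LN feed-forward (Linear -> ReLU -> Linear) plus residual.
        x = x + self.ff2(jax.nn.relu(self.ff1(self.ln2(x))))
        return x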

Compare with Linen, where parameters live in a separate variables dict that you create with init and pass to apply, and flags such as train=... or deterministic=... get threaded through submodules. In nnx the block is just a class that owns five attributes, parameters included; calling it is model(x).

Why two LayerNorms?

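Each sublayer (attention, FFN) gets its own LayerNorm because each reads the residual stream at a different point: ln1 sees the block input, while ln2 sees that input plus the attention output, which has different statistics. Every LayerNorm also carries its own learned scale and bias, so sharing one would force the same rescaling on both sublayers, coupling their pre-activations and hurting expressivity. This is why even the smallest blocks have ln1 and ln2.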

Why FFN width = 4 * d_model in production?

Convention. The original paper used d_ff = 4 * d_model (d_model=512, d_ff=2048 in its base model), and many later models kept the ratio. One rationale offered for it: the MHA sublayer tends to be bandwidth-limited (matmul + softmax), while the FFN is FLOP-limited, so a wider FFN balances the two. Here we just pass d_ff as a hyperparameter so you can experiment.
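
For a sense of what the 4x ratio costs, the arithmetic below counts weights and biases per sublayer at the paper's base sizes. It assumes bias terms on every linear layer, as nnx.Linear has by default, and ignores the LayerNorm parameters.

```python
def attention_params(d_model):
    # Four d_model -> d_model projections (Q, K, V, output), each with a bias.
    return 4 * (d_model * d_model + d_model)

def ffn_params(d_model, d_ff):
    # ff1: d_model -> d_ff, ff2: d_ff -> d_model, each with a bias.
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

d_model, d_ff = 512, 2048
attn = attention_params(d_model)   # 1_050_624
ffn = ffn_params(d_model, d_ff)    # 2_099_712
ratio = ffn / attn                 # just under 2
```

At this ratio the FFN holds roughly two thirds of the block's parameters, so changing d_ff is the most direct lever on block size.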

Common pitfalls

  • Forgetting the residual. Without it, the whole point of the block (gradient highway) disappears. Always x = x + Sublayer(...).
  • Post-LN instead of pre-LN. Both work; we standardize on pre-LN. Make sure to apply LN to the input of each sublayer, not the output.
  • Sharing one LN between sublayers. Two LNs, two normalizations.
  • FFN with no nonlinearity. Two linear layers in a row with nothing between collapse to a single linear layer. The ReLU is load-bearing.
  • d_model not divisible by num_heads. Each head gets d_model // num_heads features, so MHA asserts d_model % num_heads == 0.
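
The "FFN with no nonlinearity" pitfall is easy to verify numerically. The sketch below uses random weights in plain JAX (the shapes and names are arbitrary choices for illustration) and merges two stacked linear layers into one.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(0), 5)
d_model, d_ff = 4, 16
w1 = jax.random.normal(k1, (d_model, d_ff))
b1 = jax.random.normal(k2, (d_ff,))
w2 = jax.random.normal(k3, (d_ff, d_model))
b2 = jax.random.normal(k4, (d_model,))
x = jax.random.normal(kx, (3, d_model))  # (T, d_model)

# Two linear layers with nothing in between...
no_act = (x @ w1 + b1) @ w2 + b2

# ...equal one linear layer with merged weight and bias.
w_merged = w1 @ w2
b_merged = b1 @ w2 + b2
single = x @ w_merged + b_merged

# With the ReLU in place, no single linear layer reproduces the output.
with_relu = jax.nn.relu(x @ w1 + b1) @ w2 + b2

collapses = bool(jnp.allclose(no_act, single, atol=1e-3))
relu_differs = not bool(jnp.allclose(with_relu, single, atol=1e-3))
```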

Problem

Write encoder_block_forward(seed, x, num_heads, d_model, d_ff):

  1. Define an inner MHA(nnx.Module) (same as pos 32) and an EncoderBlock(nnx.Module) with ln1, attn, ln2, ff1, ff2 attributes.
  2. __call__(x) does x + attn(ln1(x)) then x + ff2(relu(ff1(ln2(x)))).
  3. Cast num_heads, d_model, d_ff from float to int. Build with nnx.Rngs(int(seed)).
  4. Return the output flattened: out.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model, d_ff: ints (passed as floats).

Output: 1-D flattened (T * d_model,).

