
NNX Transformer Decoder Block

Why this matters

The decoder block is the asymmetric twin of the encoder block. The encoder has two sublayers (self-attention + FFN); the decoder has THREE (causal self-attention + cross-attention + FFN). The extra sublayer is what makes encoder-decoder Transformers possible (the original seq2seq Transformer, T5, BART): cross-attention is the channel through which the decoder reads the encoder's output.

Decoder-only LMs (GPT, LLaMA) are a strict simplification: drop the cross-attention sublayer and you get a stack of causal-only blocks. Drop the causal mask as well and you get the encoder block. So the decoder block is the most general member of the Transformer family; the encoder block and the GPT block are both special cases of it.
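To make the "special cases" claim concrete, here is a parameter-free toy in plain JAX. It is a sketch under heavy simplifications that are not in the original: single-head attention with no learned projections, LayerNorm without scale or bias, and a ReLU standing in for the FFN. The names attend, layer_norm, and block are hypothetical. One function covers all three variants by toggling causal and whether an encoder output enc is passed.

```python
import jax
import jax.numpy as jnp


def attend(x_q, x_kv, causal):
    # Toy single-head scaled dot-product attention (no learned projections).
    scores = x_q @ x_kv.T / jnp.sqrt(x_q.shape[-1])  # (Tq, Tk)
    if causal:
        Tq, Tk = scores.shape
        scores = jnp.where(jnp.tril(jnp.ones((Tq, Tk))) == 0, -1e9, scores)
    return jax.nn.softmax(scores, axis=-1) @ x_kv     # (Tq, d)


def layer_norm(x, eps=1e-6):
    # LayerNorm over the feature axis, without learnable scale/bias (toy).
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / jnp.sqrt(x.var(-1, keepdims=True) + eps)


def block(x, enc=None, causal=True):
    # Decoder block:   causal=True,  enc given.
    # GPT-style block: causal=True,  enc=None (cross-attention dropped).
    # Encoder block:   causal=False, enc=None (no mask, no cross-attention).
    h = layer_norm(x)
    x = x + attend(h, h, causal)
    if enc is not None:
        x = x + attend(layer_norm(x), enc, causal=False)
    x = x + jax.nn.relu(layer_norm(x))  # stand-in for the FFN sublayer
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
enc = jax.random.normal(jax.random.PRNGKey(1), (6, 8))
decoder_out = block(x, enc)
gpt_out = block(x)
encoder_out = block(x, causal=False)
print(decoder_out.shape, gpt_out.shape, encoder_out.shape)  # (4, 8) (4, 8) (4, 8)
```

In the GPT-style and encoder variants the last row comes out identical, because the causal mask leaves the last query's row fully unmasked; only earlier rows differ.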

The three sublayers

Pre-LN, in order:

  1. Causal self-attention. Q, K, and V are all projected from x_q (after LN1). The mask makes it autoregressive: each query position attends only to itself and earlier positions. This is the same as pos 33 (nnx-mha-causal).

  2. Cross-attention. Q from x_q (the decoder hidden state), K and V from x_kv (the encoder output). NO causal mask: the decoder can read any encoder position. This is the bridge between encoder and decoder.

  3. Position-wise FFN. Same as encoder: Dense -> ReLU -> Dense. Width d_ff (typically 4 * d_model).
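A quick shape-and-mask check of sublayers 1 and 2, written in plain jax.numpy with the learned projections left out (a simplification for illustration only): causal self-attention yields a lower-triangular (Tq, Tq) weight matrix, while cross-attention yields a fully populated (Tq, Tk) matrix.

```python
import jax
import jax.numpy as jnp

Tq, Tk, d = 4, 6, 8
x_q = jax.random.normal(jax.random.PRNGKey(0), (Tq, d))   # decoder states
x_kv = jax.random.normal(jax.random.PRNGKey(1), (Tk, d))  # encoder output

# 1. Causal self-attention: scores are (Tq, Tq); mask out key j > query i.
self_scores = x_q @ x_q.T / jnp.sqrt(d)
causal_mask = jnp.tril(jnp.ones((Tq, Tq)))
self_w = jax.nn.softmax(jnp.where(causal_mask == 0, -1e9, self_scores), axis=-1)

# 2. Cross-attention: scores are (Tq, Tk); no mask at all.
cross_scores = x_q @ x_kv.T / jnp.sqrt(d)
cross_w = jax.nn.softmax(cross_scores, axis=-1)

print(self_w.shape, cross_w.shape)                   # (4, 4) (4, 6)
print(bool(jnp.allclose(jnp.triu(self_w, k=1), 0)))  # True: no weight on future positions
```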

Each sublayer is wrapped in LayerNorm (pre-LN) and a residual:

x_q -> LN1 -> SelfAttn(causal=True) -> + residual
    -> LN2 -> CrossAttn(Q=from_x_q, K=V=x_kv)   -> + residual
    -> LN3 -> FFN                               -> + residual
    -> out
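The diagram is three applications of one pattern, x + sublayer(LN(x)). A minimal sketch of that pattern, assuming a hypothetical helper pre_ln_residual and a LayerNorm without learnable parameters. The only contract it places on a sublayer is that it maps (Tq, d_model) back to (Tq, d_model), which is why cross-attention must return one row per decoder position regardless of Tk. The ReLU below is only a stand-in for the real FFN.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Normalize each row over the feature axis (no learnable scale/bias here).
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / jnp.sqrt(x.var(-1, keepdims=True) + eps)


def pre_ln_residual(x, sublayer):
    # Pre-LN: the sublayer sees the normalized input, the residual adds raw x.
    return x + sublayer(layer_norm(x))


Tq, Tk, d = 3, 5, 4
x_q = jax.random.normal(jax.random.PRNGKey(0), (Tq, d))
x_kv = jax.random.normal(jax.random.PRNGKey(1), (Tk, d))

# Toy cross-attention: Tq output rows even though x_kv has Tk rows.
cross = lambda h: jax.nn.softmax(h @ x_kv.T / jnp.sqrt(d), axis=-1) @ x_kv
ffn = lambda h: jax.nn.relu(h)  # stand-in for Dense -> ReLU -> Dense

out = pre_ln_residual(x_q, cross)
out = pre_ln_residual(out, ffn)
print(out.shape)  # (3, 4)
```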

One MHA module, two callsites

The trick that simplifies the implementation: write ONE MHA module with a causal: bool flag and (x_q, x_kv) signature. Self-attention calls it with (x, x, causal=True). Cross-attention calls it with (x_q, x_kv, causal=False). The “self” in self-attention is just “the same tensor for Q and K/V.”

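import jax
import jax.numpy as jnp
from flax import nnx

class MHA(nnx.Module):
    def __init__(self, d_model, num_heads, rngs):
        # Same projections as pos 32: q_proj, k_proj, v_proj, out_proj.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x_q, x_kv, causal=False):
        Tq, _ = x_q.shape   # queries come from x_q
        Tk, _ = x_kv.shape  # keys and values come from x_kv
        H, Dh = self.num_heads, self.head_dim
        # Project, then split into heads: (T, d_model) -> (H, T, Dh).
        q = self.q_proj(x_q).reshape(Tq, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, Tq, Tk)
        if causal:
            # Query i may only attend to keys j <= i.
            mask = jnp.tril(jnp.ones((Tq, Tk)))
            scores = jnp.where(mask == 0, -1e9, scores)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)  # (H, Tq, Dh)
        # Merge heads back: (H, Tq, Dh) -> (Tq, H * Dh).
        concat = per_head.transpose(1, 0, 2).reshape(Tq, H * Dh)
        return self.out_proj(concat)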

Now the decoder block has TWO MHA submodule attributes (one for self, one for cross), and the only difference between them is which tensors you feed and the causal flag.
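The two call sites can be checked numerically. The sketch below is a functional, pure-JAX mirror of the MHA module; the helper names mha and init_params and the bias-free random initialization are hypothetical. It verifies the causal property (perturbing the last decoder position leaves every earlier output row unchanged) and shows that cross-attention, with its own parameters, returns one row per query even when Tk differs from Tq.

```python
import jax
import jax.numpy as jnp


def mha(params, x_q, x_kv, num_heads, causal=False):
    # Functional mirror of the MHA module; params holds four (d, d) matrices.
    Tq, d = x_q.shape
    Tk = x_kv.shape[0]
    Dh = d // num_heads

    def split(t, T):
        # (T, d) -> (H, T, Dh)
        return t.reshape(T, num_heads, Dh).transpose(1, 0, 2)

    q = split(x_q @ params["q"], Tq)
    k = split(x_kv @ params["k"], Tk)
    v = split(x_kv @ params["v"], Tk)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(Dh)  # (H, Tq, Tk)
    if causal:
        scores = jnp.where(jnp.tril(jnp.ones((Tq, Tk))) == 0, -1e9, scores)
    out = jax.nn.softmax(scores, axis=-1) @ v          # (H, Tq, Dh)
    return out.transpose(1, 0, 2).reshape(Tq, d) @ params["o"]


def init_params(keys, d):
    # Hypothetical init: separate random projections per call site, no biases.
    return {name: jax.random.normal(k, (d, d)) / jnp.sqrt(d)
            for name, k in zip("qkvo", keys)}


d, H, T = 8, 2, 5
keys = jax.random.split(jax.random.PRNGKey(0), 10)
self_params, cross_params = init_params(keys[0:4], d), init_params(keys[4:8], d)
x = jax.random.normal(keys[8], (T, d))    # decoder states
enc = jax.random.normal(keys[9], (7, d))  # encoder output, Tk = 7

# Self-attention call site: the same tensor for Q and K/V, causal=True.
y = mha(self_params, x, x, H, causal=True)
# Perturbing only the last position must leave all earlier outputs unchanged.
x2 = x.at[-1].add(1.0)
y2 = mha(self_params, x2, x2, H, causal=True)
print(bool(jnp.allclose(y[:-1], y2[:-1])))  # True

# Cross-attention call site: Q from x, K/V from enc, no mask.
z = mha(cross_params, x, enc, H)
print(z.shape)  # (5, 8)
```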

Worked sketch

class DecoderBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)
        self.self_attn = MHA(d_model, num_heads, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.cross_attn = MHA(d_model, num_heads, rngs=rngs)
        self.ln3 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)

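    def __call__(self, x_q, x_kv):
        # 1. Causal self-attention over the decoder's own positions.
        h = self.ln1(x_q)
        x_q = x_q + self.self_attn(h, h, causal=True)
        # 2. Cross-attention: Q from the decoder, K/V from the encoder output, no mask.
        x_q = x_q + self.cross_attn(self.ln2(x_q), x_kv, causal=False)
        # 3. Position-wise FFN: Linear -> ReLU -> Linear.
        x_q = x_q + self.ff2(jax.nn.relu(self.ff1(self.ln3(x_q))))
        return x_q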

Three LNs, two MHAs, two Dense layers — seven submodule attributes. Three residuals.
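A worked parameter count for those seven submodules, as a sketch: it assumes every nnx.Linear carries a bias and every nnx.LayerNorm a per-feature scale and bias (the defaults for both), and the function name is hypothetical. Each MHA holds four d_model x d_model projections plus biases, each LayerNorm holds 2 * d_model values, and the FFN holds two weight matrices plus biases.

```python
def decoder_block_param_count(d_model, d_ff):
    # Assumes bias in every Linear and scale + bias in every LayerNorm.
    mha = 4 * (d_model * d_model + d_model)                     # q, k, v, out projections
    ln = 2 * d_model                                            # scale + bias
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # ff1 + ff2
    return 2 * mha + 3 * ln + ffn  # self-attn, cross-attn, three LNs, FFN


print(decoder_block_param_count(8, 32))      # 1176
print(decoder_block_param_count(512, 2048))  # 4204032
```

The cross-attention MHA and the extra LayerNorm are exactly what an encoder block of the same width lacks.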

Why no causal mask on cross-attention?

The decoder reading from the encoder isn’t autoregressive — at every decoder step, the encoder output is already complete. There’s nothing to “leak from the future” in the encoder; the encoder ran first. Causal masking only applies within the self-attention pass over the decoder’s own positions.
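A causal mask here is actively harmful, not just unnecessary. When Tq differs from Tk, jnp.tril aligns the diagonal to the top-left corner of the (Tq, Tk) grid, so decoder step 0 would see only encoder position 0 and the rest of the source sequence would be hidden for no reason. A small illustration:

```python
import jax.numpy as jnp

Tq, Tk = 2, 5  # 2 decoder positions, 5 encoder positions

# Wrong: tril aligns the diagonal to the top-left corner of the (Tq, Tk) grid.
wrong_mask = jnp.tril(jnp.ones((Tq, Tk)))
# Right: cross-attention is unmasked, so every step sees every encoder position.
right_mask = jnp.ones((Tq, Tk))

print(wrong_mask.sum(axis=-1).tolist())  # [1.0, 2.0] encoder positions visible per step
print(right_mask.sum(axis=-1).tolist())  # [5.0, 5.0]
```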

Common pitfalls

  • Cross-attention with (x_q, x_q). That’s just self-attention again — you’ve ignored the encoder. Q must come from x_q, K/V from x_kv.
  • Causal mask on cross-attention. Wrong; cross-attention sees the whole encoder output at every step.
  • Self-attention without a causal mask. With teacher forcing, each position could attend to the very token it is trained to predict and learn to copy it; at inference time those future tokens don't exist yet. Always mask self-attention in a decoder.
  • Sharing one LayerNorm. Three sublayers, three LayerNorms.
  • Same MHA instance for self and cross. They have separate learnable projections — must be two distinct nnx.Module attributes.
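The first pitfall can be caught with a perturbation test: feed two different encoder outputs and check that the cross-attention result changes. A sketch using a toy, projection-free attention (attend, cross, and buggy_cross are hypothetical names):

```python
import jax
import jax.numpy as jnp


def attend(x_q, x_kv):
    # Unmasked single-head attention without learned projections (toy).
    w = jax.nn.softmax(x_q @ x_kv.T / jnp.sqrt(x_q.shape[-1]), axis=-1)
    return w @ x_kv


def cross(x_q, x_kv):
    return attend(x_q, x_kv)  # correct: K/V come from the encoder output


def buggy_cross(x_q, x_kv):
    return attend(x_q, x_q)   # pitfall: x_kv is accepted but never used


x_q = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
x_kv = jax.random.normal(jax.random.PRNGKey(1), (6, 4))
x_kv_other = x_kv + 1.0  # a different "encoder output"

print(bool(jnp.allclose(cross(x_q, x_kv), cross(x_q, x_kv_other))))              # False
print(bool(jnp.allclose(buggy_cross(x_q, x_kv), buggy_cross(x_q, x_kv_other))))  # True
```

The buggy version returns the same output no matter what the encoder produced, which is the symptom to test for.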

Problem

Write decoder_block_forward(seed, x_q, x_kv, num_heads, d_model, d_ff):

  1. Inner MHA(nnx.Module) taking (x_q, x_kv, causal: bool). Build the lower-triangular mask from (Tq, Tk) when causal=True.
  2. DecoderBlock(nnx.Module) with three LayerNorms, two MHAs (self_attn, cross_attn), and two nnx.Linears for the FFN.
  3. __call__(x_q, x_kv) applies the three pre-LN + residual sublayers.
  4. Cast int hyperparameters from float; build with nnx.Rngs(int(seed)).
  5. Return model(x_q, x_kv).reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x_q: 2-D (Tq, d_model).
  • x_kv: 2-D (Tk, d_model).
  • num_heads, d_model, d_ff: ints (passed as floats).

Output: 1-D flattened (Tq * d_model,).

Hints

flax nnx transformer decoder cross-attention architecture
