NNX Transformer Decoder Block
Why this matters
The decoder block is the asymmetric twin of the encoder block. The encoder has two sublayers (self-attention + FFN); the decoder has THREE (causal self-attention + cross-attention + FFN). The extra sublayer is what makes encoder-decoder Transformers (T5, the original seq2seq Transformer, BART) possible, because cross-attention is the channel through which the decoder reads from the encoder.
Decoder-only LMs (GPT, LLaMA) are a strict simplification: drop the cross-attention sublayer and you get a stack of causal-only blocks. So the decoder block is the most general member of the Transformer family — encoder block and GPT block are both special cases.
The three sublayers
Pre-LN, in order:
- Causal self-attention. Q, K, V all from x_q. The mask makes it autoregressive: each query position attends only to itself and earlier positions. This is the same as pos 33 (nnx-mha-causal).
- Cross-attention. Q from x_q (the decoder hidden state), K and V from x_kv (the encoder output). NO causal mask: the decoder can read any encoder position. This is the bridge between encoder and decoder.
- Position-wise FFN. Same as encoder: Dense -> ReLU -> Dense. Width d_ff (typically 4 * d_model).
Each sublayer is wrapped in LayerNorm (pre-LN) and a residual:
x_q -> LN1 -> SelfAttn(causal=True)            -> + residual
    -> LN2 -> CrossAttn(Q=from_x_q, K=V=x_kv)  -> + residual
    -> LN3 -> FFN                              -> + residual
    -> out
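To make the causal mask in the first sublayer concrete, here is a minimal sketch in plain JAX. It assumes a toy sequence length of 4 and all-zero raw scores (both chosen only for illustration), so the effect of the mask is easy to read off the softmax output.

```python
import jax
import jax.numpy as jnp

T = 4  # illustrative sequence length
scores = jnp.zeros((T, T))  # uniform raw scores, so only the mask shapes the result

# Lower-triangular mask: entry (i, j) is 1 when query i may attend to key j (j <= i).
mask = jnp.tril(jnp.ones((T, T)))
masked = jnp.where(mask == 0, -1e9, scores)
weights = jax.nn.softmax(masked, axis=-1)

# Row i spreads its weight evenly over positions 0..i and puts ~0 on later ones.
print(weights)
```

Each row still sums to 1, which is why the mask is applied to the scores before the softmax rather than to the weights after it.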
One MHA module, two callsites
The trick that simplifies the implementation: write ONE MHA module
with a causal: bool flag and (x_q, x_kv) signature. Self-attention
calls it with (x, x, causal=True). Cross-attention calls it with
(x_q, x_kv, causal=False). The “self” in self-attention is just
“the same tensor for Q and K/V.”
import jax
import jax.numpy as jnp
from flax import nnx

class MHA(nnx.Module):
    # __init__: q_proj, k_proj, v_proj, out_proj (same as pos 32).

    def __call__(self, x_q, x_kv, causal=False):
        Tq, _ = x_q.shape
        Tk, _ = x_kv.shape
        H, Dh = self.num_heads, self.head_dim
        # Project, then split into heads: (T, d_model) -> (H, T, Dh).
        q = self.q_proj(x_q).reshape(Tq, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        # Scaled dot-product scores, shape (H, Tq, Tk).
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
        if causal:
            # Lower-triangular mask: query i sees keys 0..i only.
            mask = jnp.tril(jnp.ones((Tq, Tk)))
            scores = jnp.where(mask == 0, -1e9, scores)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)  # (H, Tq, Dh)
        # Merge heads back: (H, Tq, Dh) -> (Tq, H * Dh).
        concat = per_head.transpose(1, 0, 2).reshape(Tq, H * Dh)
        return self.out_proj(concat)
Now the decoder block has TWO MHA submodule attributes (one for self,
one for cross), and the only difference between them is which tensors
you feed and the causal flag.
Worked sketch
class DecoderBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_ff, rngs):
        self.ln1 = nnx.LayerNorm(d_model, rngs=rngs)
        self.self_attn = MHA(d_model, num_heads, rngs=rngs)
        self.ln2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.cross_attn = MHA(d_model, num_heads, rngs=rngs)
        self.ln3 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ff1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ff2 = nnx.Linear(d_ff, d_model, rngs=rngs)

    def __call__(self, x_q, x_kv):
        h = self.ln1(x_q)
        x_q = x_q + self.self_attn(h, h, causal=True)
        x_q = x_q + self.cross_attn(self.ln2(x_q), x_kv, causal=False)
        x_q = x_q + self.ff2(jax.nn.relu(self.ff1(self.ln3(x_q))))
        return x_q
Three LNs, two MHAs, two Dense layers — seven submodule attributes. Three residuals.
Why no causal mask on cross-attention?
The decoder reading from the encoder isn’t autoregressive — at every decoder step, the encoder output is already complete. There’s nothing to “leak from the future” in the encoder; the encoder ran first. Causal masking only applies within the self-attention pass over the decoder’s own positions.
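A small shape-level sketch of this point, in plain JAX. It assumes a single head with no learned projections (Q = x_q, K = V = x_kv) and made-up sizes of 3 decoder positions and 5 encoder positions. The score matrix is (Tq, Tk), so there is no "diagonal" that a causal mask could meaningfully follow.

```python
import jax
import jax.numpy as jnp

Tq, Tk, d = 3, 5, 8  # illustrative sizes: 3 decoder positions, 5 encoder positions
key_q, key_kv = jax.random.split(jax.random.PRNGKey(0))
x_q = jax.random.normal(key_q, (Tq, d))    # stands in for the decoder hidden state
x_kv = jax.random.normal(key_kv, (Tk, d))  # stands in for the encoder output

# Single head, no projections (an assumption made to keep the sketch short).
scores = x_q @ x_kv.T / jnp.sqrt(d)        # shape (Tq, Tk)
weights = jax.nn.softmax(scores, axis=-1)  # no mask applied
out = weights @ x_kv                       # shape (Tq, d)

print(scores.shape, out.shape)
# Every decoder row places nonzero weight on every encoder position.
print(bool((weights > 0).all()))
```

The output keeps the decoder length Tq regardless of how long the encoder sequence is, which is what lets the residual connection add it back onto x_q.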
Common pitfalls
- Cross-attention with (x_q, x_q). That’s just self-attention again: you’ve ignored the encoder. Q must come from x_q, K/V from x_kv.
- Causal mask on cross-attention. Wrong; cross-attention sees the whole encoder output at every step.
- Self-attention without causal mask. During training, the decoder would learn to copy from future positions. Always mask self-attention in a decoder.
- Sharing one LayerNorm. Three sublayers, three LayerNorms.
- Same MHA instance for self and cross. They have separate learnable projections, so they must be two distinct nnx.Module attributes.
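One way to catch the missing-mask pitfall is a perturbation check: change the last position and see whether any earlier output row moves. The sketch below uses a simplified single-head self-attention without learned projections; self_attention and leaks_future are hypothetical helper names introduced here for illustration, not part of the problem's API.

```python
import jax
import jax.numpy as jnp

def self_attention(x, causal):
    """Single-head self-attention without learned projections (illustration only)."""
    T, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)
    if causal:
        mask = jnp.tril(jnp.ones((T, T)))
        scores = jnp.where(mask == 0, -1e9, scores)
    return jax.nn.softmax(scores, axis=-1) @ x

def leaks_future(fn, x):
    """Return True if changing the last position alters any earlier output row."""
    x_changed = x.at[-1].add(1.0)
    return not bool(jnp.allclose(fn(x)[:-1], fn(x_changed)[:-1], atol=1e-6))

x = jax.random.normal(jax.random.PRNGKey(0), (6, 8))
masked_leaks = leaks_future(lambda t: self_attention(t, causal=True), x)
unmasked_leaks = leaks_future(lambda t: self_attention(t, causal=False), x)
print(masked_leaks, unmasked_leaks)
```

With the mask in place the earlier rows should be unaffected by the change, while the unmasked variant should report a leak. The same check can be pointed at a full decoder block by perturbing the last row of x_q while holding x_kv fixed.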
Problem
Write decoder_block_forward(seed, x_q, x_kv, num_heads, d_model, d_ff):
- Inner MHA(nnx.Module) taking (x_q, x_kv, causal: bool). Build the lower-triangular mask from (Tq, Tk) when causal=True.
- DecoderBlock(nnx.Module) with three LayerNorms, two MHAs (self_attn, cross_attn), and two nnx.Linears for the FFN.
- __call__(x_q, x_kv) applies the three pre-LN + residual sublayers.
- Cast int hyperparameters from float; build with nnx.Rngs(int(seed)).
- Return model(x_q, x_kv).reshape(-1).
Inputs:
- seed: int (passed as float).
- x_q: 2-D (Tq, d_model).
- x_kv: 2-D (Tk, d_model).
- num_heads, d_model, d_ff: ints (passed as floats).
Output: 1-D flattened (Tq * d_model,).