hard end_to_end

LLaMA-Style Transformer Block

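Implement a LLaMA-style transformer block: pre-RMSNorm, RoPE on Q/K, causal multi-head attention, and a SwiGLU feed-forward network. This is the sub-layer layout shared by LLaMA-1, LLaMA-2, LLaMA-3, Mistral, and most open-weight LLMs since 2023; production models add details this exercise omits, such as a learned RMSNorm gain and, in several of them, grouped-query attention.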

Innovations vs the Original Transformer

Component   Original (Vaswani 2017)     LLaMA
Norm        Post-LayerNorm              Pre-RMSNorm
Position    Absolute sinusoidal         RoPE on Q/K
FFN         ReLU / GELU (2 matrices)    SwiGLU (3 matrices)
Attention   Bidirectional (encoder)     Causal (decoder)

Pre-RMSNorm

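Apply normalisation before each sub-layer (attention and FFN), not after. In this “pre-norm” layout the residual path itself is never normalised, so gradients reach earlier layers through an identity connection, which tends to make deep stacks easier to train.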

RMSNorm skips mean-centering: it only rescales by the root-mean-square of the features. In this exercise there is no learned gamma: the output is x / sqrt(mean(x²) + eps), with the mean taken over the last (d_model) axis and eps = 1e-6.
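A minimal NumPy sketch of this normalisation, assuming the mean is taken over the last axis (the function name `rms_norm` and the sample input are illustrative, not part of the problem statement):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Mean of squares over the feature (last) axis; keepdims lets it broadcast back.
    mean_sq = np.mean(x * x, axis=-1, keepdims=True)
    # No mean subtraction and no learned gain: only a rescale.
    return x / np.sqrt(mean_sq + eps)

x = np.array([[3.0, 4.0], [1.0, 1.0]])
y = rms_norm(x)
# Each row of y now has a mean square very close to 1,
# but its mean is not shifted to zero as LayerNorm would do.
```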

RoPE (Rotary Position Embedding)

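Applied to Q and K only, immediately after the multi-head reshape. Uses the same even/odd pair rotation as in transformer-with-rope: at position t, each pair (x[2i], x[2i+1]) is rotated by the angle whose cosine and sine are freqs_cos[t, i] and freqs_sin[t, i].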
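The rotation can be sketched in NumPy as follows. This assumes the interleaved even/odd pairing and tensors shaped (N, num_heads, T, d_head); the helper name `apply_rope` and the base-10000 frequency table are illustrative assumptions, since the problem supplies freqs_cos and freqs_sin as inputs.

```python
import numpy as np

def apply_rope(x, freqs_cos, freqs_sin):
    # x: (N, num_heads, T, d_head); freqs_*: (T, d_head/2), broadcast over N and heads.
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    out = np.empty_like(x)
    # Standard 2-D rotation applied to each (even, odd) feature pair.
    out[..., 0::2] = x_even * freqs_cos - x_odd * freqs_sin
    out[..., 1::2] = x_even * freqs_sin + x_odd * freqs_cos
    return out

# Hypothetical frequency table (base 10000 is an assumption, not given in the problem).
T, d_head = 4, 6
inv_freq = 1.0 / 10000 ** (np.arange(0, d_head, 2) / d_head)
angles = np.outer(np.arange(T), inv_freq)          # (T, d_head/2)
freqs_cos, freqs_sin = np.cos(angles), np.sin(angles)

q = np.random.default_rng(0).normal(size=(2, 3, T, d_head))
q_rot = apply_rope(q, freqs_cos, freqs_sin)
```

Because each pair is only rotated, the length of every per-head vector is unchanged, and position 0 (angle zero) passes through untouched.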

Causal Mask

The same lower-triangular mask as in causal-self-attention-block: positions (i, j) with j > i are set to -1e9 before softmax.
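A small sketch of the masking step, assuming scores shaped (..., T, T). The function name `causal_softmax` is illustrative, and subtracting the row maximum before exponentiating is a numerical-stability habit added here rather than something the problem requires.

```python
import numpy as np

def causal_softmax(scores):
    # scores: (..., T, T), rows index the query position i, columns the key position j.
    T = scores.shape[-1]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True where j > i
    masked = np.where(future, -1e9, scores)
    masked = masked - masked.max(axis=-1, keepdims=True)  # stability only
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

# With all-zero scores, each row spreads its weight evenly over the visible positions.
weights = causal_softmax(np.zeros((3, 3)))
```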

SwiGLU FFN

Three weight matrices instead of two:

gate   = SiLU(norm @ w_gate)    # SiLU(x) = x * sigmoid(x)
up     = norm @ w_up
ffn_out = (gate * up) @ w_down  # element-wise gate, then down-project

The gated unit lets the network learn to suppress irrelevant features at each position before the down-projection.
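The three-matrix FFN above could be written in NumPy roughly like this. The names `silu` and `swiglu_ffn`, the random weights, and the sizes are illustrative assumptions.

```python
import numpy as np

def silu(x):
    # x * sigmoid(x), written as a single division.
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(norm, w_gate, w_up, w_down):
    gate = silu(norm @ w_gate)       # (..., d_ff)
    up = norm @ w_up                 # (..., d_ff)
    return (gate * up) @ w_down      # element-wise gate, then back to d_model

rng = np.random.default_rng(0)
d_model, d_ff = 4, 8
norm = rng.normal(size=(2, 3, d_model))
w_gate = rng.normal(size=(d_model, d_ff))
w_up = rng.normal(size=(d_model, d_ff))
w_down = rng.normal(size=(d_ff, d_model))
ffn_out = swiglu_ffn(norm, w_gate, w_up, w_down)
```

Where a gate pre-activation is strongly negative, SiLU drives it close to zero, so the matching `up` feature contributes almost nothing to the down-projection.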

Full Pipeline

  1. norm1 = RMSNorm(x).
  2. Q = norm1 @ w_q, K = norm1 @ w_k, V = norm1 @ w_v; reshape each to (N, num_heads, T, d_head).
  3. Apply RoPE to Q and K (not V).
  4. Causal scaled dot-product attention: scores = Q @ Kᵀ / sqrt(d_head); build a (T, T) lower-triangular mask, fill the upper triangle with -1e9, apply softmax over the last axis, then take the weighted sum over V.
  5. Concat heads (transpose back and reshape to (N, T, d_model)), then output projection: attn_out = concat @ w_o.
  6. Residual: x = x + attn_out.
  7. norm2 = RMSNorm(x).
  8. SwiGLU: gate = SiLU(norm2 @ w_gate), up = norm2 @ w_up, ffn_out = (gate * up) @ w_down.
  9. Residual: x = x + ffn_out.
  10. Return x.
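The ten steps above can be sketched end to end in NumPy. This is one possible reading of the pipeline, not the reference solution: the helper names, the row-max subtraction in the softmax, the base-10000 frequency table, and the random test weights are all assumptions made for illustration.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def apply_rope(x, cos, sin):
    # Rotate interleaved (even, odd) feature pairs; cos/sin are (T, d_head/2).
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

def silu(x):
    return x / (1.0 + np.exp(-x))

def llama_block(x, w_q, w_k, w_v, w_o, w_gate, w_up, w_down,
                num_heads, freqs_cos, freqs_sin):
    N, T, d_model = x.shape
    d_head = d_model // num_heads

    def split_heads(t):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return t.reshape(N, T, num_heads, d_head).transpose(0, 2, 1, 3)

    # Steps 1-3: pre-norm, projections, RoPE on Q and K only.
    norm1 = rms_norm(x)
    q = apply_rope(split_heads(norm1 @ w_q), freqs_cos, freqs_sin)
    k = apply_rope(split_heads(norm1 @ w_k), freqs_cos, freqs_sin)
    v = split_heads(norm1 @ w_v)

    # Step 4: causal scaled dot-product attention.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)    # True where j > i
    scores = np.where(future, -1e9, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stability only
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ v                                   # (N, num_heads, T, d_head)

    # Steps 5-6: merge heads, output projection, residual.
    concat = heads.transpose(0, 2, 1, 3).reshape(N, T, d_model)
    x = x + concat @ w_o

    # Steps 7-9: pre-norm, SwiGLU, residual.
    norm2 = rms_norm(x)
    ffn_out = (silu(norm2 @ w_gate) * (norm2 @ w_up)) @ w_down
    return x + ffn_out

# Usage with small random tensors (all sizes and values are hypothetical).
rng = np.random.default_rng(0)
N, T, d_model, num_heads, d_ff = 2, 5, 8, 2, 16
d_head = d_model // num_heads
inv_freq = 1.0 / 10000 ** (np.arange(0, d_head, 2) / d_head)
angles = np.outer(np.arange(T), inv_freq)
freqs_cos, freqs_sin = np.cos(angles), np.sin(angles)

x = rng.normal(size=(N, T, d_model))
params = (
    rng.normal(size=(d_model, d_model)) * 0.1,  # w_q
    rng.normal(size=(d_model, d_model)) * 0.1,  # w_k
    rng.normal(size=(d_model, d_model)) * 0.1,  # w_v
    rng.normal(size=(d_model, d_model)) * 0.1,  # w_o
    rng.normal(size=(d_model, d_ff)) * 0.1,     # w_gate
    rng.normal(size=(d_model, d_ff)) * 0.1,     # w_up
    rng.normal(size=(d_ff, d_model)) * 0.1,     # w_down
)
out = llama_block(x, *params, num_heads, freqs_cos, freqs_sin)
```

A quick sanity check for causality is to perturb only the last token of x and confirm that the outputs at earlier positions do not change.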

Inputs / Output

  • x: (N, T, d_model).
  • w_q, w_k, w_v, w_o: (d_model, d_model).
  • w_gate, w_up: (d_model, d_ff).
  • w_down: (d_ff, d_model).
  • num_heads: int; must divide d_model, and d_head = d_model / num_heads must be even (RoPE rotates features in pairs).
  • freqs_cos: (T, d_head/2).
  • freqs_sin: (T, d_head/2).
  • Output: (N, T, d_model).

Hints

causal-lm rope rmsnorm swiglu
