LLaMA-Style Transformer Block
Implement a LLaMA-style transformer block: pre-RMSNorm, RoPE on Q/K, causal multi-head attention, and a SwiGLU feed-forward network. This sub-layer stack is used in LLaMA-1, LLaMA-2, LLaMA-3, Mistral, and most open-weight LLMs since 2023, sometimes with variations such as grouped-query attention.
Innovations vs the Original Transformer
| Component | Original (Vaswani 2017) | LLaMA |
|---|---|---|
| Norm | Post-LayerNorm | Pre-RMSNorm |
| Position | Absolute sinusoidal | RoPE on Q/K |
| FFN | ReLU (2 matrices) | SwiGLU (3 matrices) |
| Attention | Bidirectional encoder + causal decoder | Causal only (decoder-only) |
Pre-RMSNorm
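Apply normalisation to the input of each sub-layer (attention and FFN), not to the residual sum after it. This "pre-norm" layout keeps the residual path an identity from block input to block output, which gives better gradient flow and more stable training in deep stacks.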
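A minimal NumPy sketch of the two layouts, assuming a generic sub-layer function. The helper names are hypothetical. Both versions use the same unweighted RMSNorm from this section (the original Transformer used LayerNorm), so only the placement of the norm differs.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Unweighted RMSNorm over the last (feature) axis.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def post_norm_sublayer(x, sublayer):
    # Original Transformer layout: add the residual first, then normalise the sum.
    return rms_norm(x + sublayer(x))

def pre_norm_sublayer(x, sublayer):
    # LLaMA layout: normalise only the sub-layer's input; the residual
    # stream x reaches the output through a pure identity path.
    return x + sublayer(rms_norm(x))

x = np.array([[1.0, 2.0, 3.0, 4.0]])
zero_sublayer = lambda h: np.zeros_like(h)    # a sub-layer that contributes nothing
pre = pre_norm_sublayer(x, zero_sublayer)     # identical to x
post = post_norm_sublayer(x, zero_sublayer)   # rescaled, although the sub-layer did nothing
```

With a sub-layer that outputs zeros, the pre-norm version returns x unchanged, while the post-norm version still rescales it. In the pre-norm layout, the residual stream is never normalised inside the block.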
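RMSNorm skips mean-centering: it only rescales each token's feature vector (the last, d_model-sized axis) by its root-mean-square.
This problem uses no learned gain γ. The output is just x / sqrt(mean(x²) + eps) with eps = 1e-6. (The released LLaMA checkpoints additionally multiply by a learned per-channel weight.)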
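A direct NumPy transcription of the formula, as a sketch. The name rms_norm is illustrative, and the mean is taken over the last (feature) axis. The second example row shows the missing mean-centering. [1, 1] already has RMS 1, so it comes back almost unchanged instead of being shifted to zero mean, as LayerNorm would do.

```python
import numpy as np

def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm without a learned gain: x / sqrt(mean(x^2) + eps) over the last axis."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms

x = np.array([[3.0, 4.0],    # RMS = sqrt(12.5), so this row is rescaled
              [1.0, 1.0]])   # RMS = 1, so this row is (almost) unchanged
y = rms_norm(x)
```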
RoPE (Rotary Position Embedding)
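Applied only to Q and K, immediately after the multi-head reshape. It uses
the same even/odd pair rotation as in transformer-with-rope: features (2i, 2i+1) of the vector at position t are rotated by the angle whose cosine and sine are freqs_cos[t, i] and freqs_sin[t, i].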
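One way to write the rotation in NumPy is shown below. It assumes the interleaved convention, where features 2i and 2i+1 form pair i. The frequency helper rope_freqs is hypothetical and uses the common base of 10000. The problem itself supplies freqs_cos and freqs_sin, so the helper only serves to produce example inputs.

```python
import numpy as np

def apply_rope(x, freqs_cos, freqs_sin):
    """Rotate interleaved feature pairs (2i, 2i+1) of x by position-dependent angles.

    x:         (N, num_heads, T, d_head) with d_head even
    freqs_cos: (T, d_head // 2), cos of the angle for pair i at position t
    freqs_sin: (T, d_head // 2), sin of the same angle
    """
    x_even = x[..., 0::2]  # features 0, 2, 4, ...
    x_odd = x[..., 1::2]   # features 1, 3, 5, ...
    out = np.empty_like(x)
    # Standard 2-D rotation of each pair; (T, d_head/2) broadcasts over N and num_heads.
    out[..., 0::2] = x_even * freqs_cos - x_odd * freqs_sin
    out[..., 1::2] = x_even * freqs_sin + x_odd * freqs_cos
    return out

def rope_freqs(T, d_head, base=10000.0):
    # Hypothetical helper (assumption): angle[t, i] = t * base^(-2i / d_head).
    inv_freq = base ** (-np.arange(0, d_head, 2) / d_head)  # (d_head/2,)
    angles = np.outer(np.arange(T), inv_freq)               # (T, d_head/2)
    return np.cos(angles), np.sin(angles)

freqs_cos, freqs_sin = rope_freqs(T=5, d_head=8)
q = np.random.default_rng(0).standard_normal((2, 3, 5, 8))  # (N, num_heads, T, d_head)
q_rot = apply_rope(q, freqs_cos, freqs_sin)
```

Each pair undergoes a pure rotation, so RoPE preserves vector norms. Position 0 (angle 0) is left unchanged.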
Causal Mask
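The same lower-triangular mask as in causal-self-attention-block: attention scores at positions
(i, j) with j > i (keys after the query) are set to -1e9 before softmax, so each position attends only to itself and earlier positions.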
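Below is a sketch of masked, scaled dot-product attention over heads that have already been split. It assumes inputs of shape (N, num_heads, T, d_head), and the function name is illustrative. Subtracting the row maximum before exponentiating is a standard stability step. With a -1e9 fill, the masked weights underflow to exactly 0 in float64.

```python
import numpy as np

def causal_attention(q, k, v):
    """Causal scaled dot-product attention.

    q, k, v: (N, num_heads, T, d_head) -> returns (N, num_heads, T, d_head)
    """
    T, d_head = q.shape[-2], q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head)  # (N, H, T, T)
    # Lower-triangular mask: query i may attend to keys j <= i only.
    mask = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(mask, scores, -1e9)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 2, 4, 8)) for _ in range(3))
out = causal_attention(q, k, v)
# Changing the last value vector must not affect earlier positions.
v_changed = v.copy()
v_changed[..., 3, :] += 100.0
out_changed = causal_attention(q, k, v_changed)
```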
SwiGLU FFN
Three weight matrices instead of two:
gate = SiLU(norm @ w_gate) # SiLU(x) = x * sigmoid(x)
up = norm @ w_up
ffn_out = (gate * up) @ w_down # element-wise gate, then down-project
Because gate multiplies up element-wise, the network can learn to suppress individual hidden features at each position (gate close to 0) before the down-projection.
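The three SwiGLU lines translate almost one-to-one into NumPy. In this sketch the helper names are hypothetical. SiLU is computed through the identity sigmoid(x) = 0.5 * (1 + tanh(x / 2)), which avoids overflow in exp for large negative inputs.

```python
import numpy as np

def silu(x):
    # SiLU(x) = x * sigmoid(x), with sigmoid written via tanh (overflow-free).
    return x * 0.5 * (1.0 + np.tanh(0.5 * x))

def swiglu_ffn(h, w_gate, w_up, w_down):
    """h: (..., d_model); w_gate, w_up: (d_model, d_ff); w_down: (d_ff, d_model)."""
    gate = silu(h @ w_gate)        # (..., d_ff)
    up = h @ w_up                  # (..., d_ff)
    return (gate * up) @ w_down    # element-wise gating, then project back to d_model

rng = np.random.default_rng(0)
h = rng.standard_normal((2, 3, 4))                                      # d_model = 4
w_gate, w_up = rng.standard_normal((4, 6)), rng.standard_normal((4, 6))  # d_ff = 6
w_down = rng.standard_normal((6, 4))
ffn_out = swiglu_ffn(h, w_gate, w_up, w_down)
```

If the gate projection is all zeros, SiLU(0) = 0 closes every gate and the FFN output is zero.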
Full Pipeline
1. norm1 = RMSNorm(x).
2. Q, K, V = norm1 @ w_q / w_k / w_v, each reshaped to (N, num_heads, T, d_head).
3. Apply RoPE to Q and K (not V).
4. Causal scaled dot-product attention: scores = Q @ K^T / sqrt(d_head); build the (T, T) lower-triangular mask and fill the upper triangle with -1e9; softmax over the key axis; weighted sum over V.
5. Concatenate the heads back to (N, T, d_model) and apply the output projection: attn_out = concat @ w_o.
6. Residual: x = x + attn_out.
7. norm2 = RMSNorm(x).
8. SwiGLU: gate = SiLU(norm2 @ w_gate), up = norm2 @ w_up, ffn_out = (gate * up) @ w_down.
9. Residual: x = x + ffn_out.
10. Return x.
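A self-contained NumPy sketch of the whole block follows. It assumes the shapes listed under Inputs / Output and the interleaved RoPE convention. llama_block and its helpers are illustrative names, not a reference solution. The usage lines build RoPE tables with the common base of 10000, which is an assumption and not part of the problem.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # No learned gain, per the problem statement.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def silu(x):
    return x * 0.5 * (1.0 + np.tanh(0.5 * x))  # x * sigmoid(x), overflow-free

def apply_rope(x, cos, sin):
    # Interleaved (even, odd) pair rotation; cos/sin are (T, d_head/2).
    out = np.empty_like(x)
    out[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
    out[..., 1::2] = x[..., 0::2] * sin + x[..., 1::2] * cos
    return out

def llama_block(x, w_q, w_k, w_v, w_o, w_gate, w_up, w_down,
                num_heads, freqs_cos, freqs_sin):
    N, T, d_model = x.shape
    d_head = d_model // num_heads

    def split_heads(t):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return t.reshape(N, T, num_heads, d_head).transpose(0, 2, 1, 3)

    # Pre-norm, then Q/K/V projections split into heads.
    h = rms_norm(x)
    q, k, v = split_heads(h @ w_q), split_heads(h @ w_k), split_heads(h @ w_v)
    # RoPE on Q and K only.
    q = apply_rope(q, freqs_cos, freqs_sin)
    k = apply_rope(k, freqs_cos, freqs_sin)
    # Causal scaled dot-product attention with a stable softmax.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -1e9)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    ctx = weights @ v  # (N, num_heads, T, d_head)
    # Concatenate heads, output projection, first residual.
    attn_out = ctx.transpose(0, 2, 1, 3).reshape(N, T, d_model) @ w_o
    x = x + attn_out
    # Pre-norm SwiGLU feed-forward, second residual.
    h = rms_norm(x)
    ffn_out = (silu(h @ w_gate) * (h @ w_up)) @ w_down
    return x + ffn_out

rng = np.random.default_rng(0)
N, T, d_model, num_heads, d_ff = 2, 5, 16, 4, 32
d_head = d_model // num_heads
x = rng.standard_normal((N, T, d_model))
w_q, w_k, w_v, w_o = (0.1 * rng.standard_normal((d_model, d_model)) for _ in range(4))
w_gate, w_up = (0.1 * rng.standard_normal((d_model, d_ff)) for _ in range(2))
w_down = 0.1 * rng.standard_normal((d_ff, d_model))
# Assumption: standard RoPE angles t * 10000^(-2i/d_head).
angles = np.outer(np.arange(T), 10000.0 ** (-np.arange(0, d_head, 2) / d_head))
freqs_cos, freqs_sin = np.cos(angles), np.sin(angles)
out = llama_block(x, w_q, w_k, w_v, w_o, w_gate, w_up, w_down,
                  num_heads, freqs_cos, freqs_sin)
```

Two properties make quick sanity checks. With every weight matrix set to zero, both residual branches output zero and the block reduces to the identity. Perturbing the last token leaves the outputs at earlier positions unchanged because of the causal mask.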
Inputs / Output
- x: (N, T, d_model).
- w_q, w_k, w_v, w_o: (d_model, d_model).
- w_gate, w_up: (d_model, d_ff).
- w_down: (d_ff, d_model).
- num_heads: int; d_model must be divisible by num_heads, and d_head = d_model / num_heads must be even (RoPE rotates feature pairs).
- freqs_cos: (T, d_head/2).
- freqs_sin: (T, d_head/2).
- Output: (N, T, d_model).
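The shape constraints above can be checked up front. The validator below is a hypothetical helper and not part of the problem's interface. It raises ValueError on any mismatch and returns d_head. d_ff is inferred from w_gate.

```python
import numpy as np

def check_llama_block_inputs(x, w_q, w_k, w_v, w_o, w_gate, w_up, w_down,
                             num_heads, freqs_cos, freqs_sin):
    """Raise ValueError if an argument violates the documented shapes; return d_head."""
    if x.ndim != 3:
        raise ValueError(f"x must be (N, T, d_model), got {x.shape}")
    _, T, d_model = x.shape
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_head = d_model // num_heads
    if d_head % 2 != 0:
        raise ValueError("d_head must be even for RoPE")
    for name, w in (("w_q", w_q), ("w_k", w_k), ("w_v", w_v), ("w_o", w_o)):
        if w.shape != (d_model, d_model):
            raise ValueError(f"{name} must be {(d_model, d_model)}, got {w.shape}")
    d_ff = w_gate.shape[1]
    if w_gate.shape != (d_model, d_ff) or w_up.shape != (d_model, d_ff):
        raise ValueError("w_gate and w_up must both be (d_model, d_ff)")
    if w_down.shape != (d_ff, d_model):
        raise ValueError("w_down must be (d_ff, d_model)")
    for name, f in (("freqs_cos", freqs_cos), ("freqs_sin", freqs_sin)):
        if f.shape != (T, d_head // 2):
            raise ValueError(f"{name} must be {(T, d_head // 2)}, got {f.shape}")
    return d_head

x = np.zeros((2, 5, 16))
w = np.zeros((16, 16))
g = np.zeros((16, 32))
dn = np.zeros((32, 16))
f = np.zeros((5, 2))  # (T, d_head/2) for num_heads = 4
d_head = check_llama_block_inputs(x, w, w, w, w, g, g, dn, 4, f, f)
```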