Multi-Head Attention Block
Implement multi-head attention with a residual connection and layer norm, the core building block of transformer models.
Why multiple heads?
A single attention head projects all tokens into one query/key/value space. Multiple heads let the model attend to information from different representation subspaces simultaneously: one head might focus on syntactic relationships, another on coreference, another on positional proximity. The outputs are concatenated and projected back to the model dimension.
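Mechanically, "different representation subspaces" means that each head works on its own d_head-wide slice of the feature axis, and concatenating the heads restores the original layout. The minimal NumPy sketch below shows this. NumPy is our choice, not something the problem prescribes, and the helper names split_heads and merge_heads are hypothetical.

```python
import numpy as np


def split_heads(t: np.ndarray, num_heads: int) -> np.ndarray:
    """(N, T, d_model) -> (N, num_heads, T, d_head); head i gets features [i*d_head, (i+1)*d_head)."""
    n, seq_len, d_model = t.shape
    d_head = d_model // num_heads
    return t.reshape(n, seq_len, num_heads, d_head).transpose(0, 2, 1, 3)


def merge_heads(t: np.ndarray) -> np.ndarray:
    """(N, num_heads, T, d_head) -> (N, T, num_heads * d_head): concatenate heads on the feature axis."""
    n, num_heads, seq_len, d_head = t.shape
    return t.transpose(0, 2, 1, 3).reshape(n, seq_len, num_heads * d_head)


x = np.arange(2 * 3 * 8, dtype=np.float64).reshape(2, 3, 8)  # N=2, T=3, d_model=8
heads = split_heads(x, num_heads=4)                          # d_head = 2
print(heads.shape)                                           # (2, 4, 3, 2)
print(np.array_equal(heads[:, 1], x[:, :, 2:4]))             # True: head 1 sees features 2 and 3
print(np.array_equal(merge_heads(heads), x))                 # True: concatenation undoes the split
```

In the full block, attention runs independently inside each head between these two steps. The output projection then mixes information across heads.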
Math (Vaswani et al., “Attention Is All You Need”, NeurIPS 2017, §3.2)
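For a single head $i$ with subspace dimension $d_{\text{head}} = d_{\text{model}} / h$, where $Q_i$, $K_i$, $V_i$ are the $i$-th $d_{\text{head}}$-wide column blocks of the projections $Q = x\,W^Q$, $K = x\,W^K$, $V = x\,W^V$: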
$$\text{head}_i = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_{\text{head}}}}\right) V_i$$
$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\ldots,\text{head}_h)\,W^O$$
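The two equations translate almost literally into a loop over heads. The sketch below handles a single unbatched sequence x of shape (T, d_model). It assumes Q = x @ w_q (and likewise for K and V) and that head i owns columns i*d_head to (i+1)*d_head of each projection. The function name multi_head_loop is hypothetical. A loop is slower than a batched reshape, but it makes the correspondence with the formulas explicit.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_loop(x, w_q, w_k, w_v, w_o, num_heads):
    """Literal translation of head_i and MultiHead for x of shape (T, d_model)."""
    d_model = x.shape[-1]
    d_head = d_model // num_heads
    q, k, v = x @ w_q, x @ w_k, x @ w_v  # full-width projections, each (T, d_model)
    heads = []
    for i in range(num_heads):
        cols = slice(i * d_head, (i + 1) * d_head)  # head i's subspace
        q_i, k_i, v_i = q[:, cols], k[:, cols], v[:, cols]
        attn = softmax(q_i @ k_i.T / np.sqrt(d_head))  # (T, T), each row sums to 1
        heads.append(attn @ v_i)  # head_i: (T, d_head)
    return np.concatenate(heads, axis=-1) @ w_o  # Concat(head_1, ..., head_h) W^O


rng = np.random.default_rng(0)
T, d_model, h = 5, 8, 2
x = rng.standard_normal((T, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))
out = multi_head_loop(x, w_q, w_k, w_v, w_o, h)
print(out.shape)  # (5, 8)
```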
This problem adds a residual connection and layer norm following the original (post-LN) transformer convention, with $Q$, $K$ and $V$ all projected from the same input $x$ (self-attention):
$$\text{output} = \text{LayerNorm}(x + \text{MultiHead}(x))$$
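For reference, LayerNorm without learned $\gamma$/$\beta$ normalises each token vector $z \in \mathbb{R}^{d_{\text{model}}}$ independently. Written out below, this assumes the biased ($1/d_{\text{model}}$) variance, which is also what PyTorch's layer_norm uses:

$$\mu = \frac{1}{d_{\text{model}}}\sum_{j=1}^{d_{\text{model}}} z_j,\qquad \sigma^2 = \frac{1}{d_{\text{model}}}\sum_{j=1}^{d_{\text{model}}} (z_j - \mu)^2,\qquad \text{LayerNorm}(z)_j = \frac{z_j - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

The Pipeline section sets $\epsilon = 10^{-5}$.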
Post-LN vs pre-LN: the original 2017 paper normalises after adding the residual (post-LN, used here). Many modern implementations, e.g. GPT-2, instead normalise the input to the attention sub-layer, computing $x + \text{MultiHead}(\text{LayerNorm}(x))$ (pre-LN), for better training stability. Both conventions appear in production code, so it’s worth knowing which is which.
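The difference between the two orderings is a single line of code. The NumPy sketch below uses a hypothetical shape-preserving toy_sublayer as a stand-in for multi-head attention. The function names post_ln_block and pre_ln_block are illustrative and do not come from any library.

```python
import numpy as np
from typing import Callable

Sublayer = Callable[[np.ndarray], np.ndarray]


def layer_norm(z: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalise over the last axis with no learned gamma/beta (biased variance)."""
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mean) / np.sqrt(var + eps)


def post_ln_block(x: np.ndarray, sublayer: Sublayer) -> np.ndarray:
    # 2017 ordering (used in this problem): add the residual, then normalise the sum.
    return layer_norm(x + sublayer(x))


def pre_ln_block(x: np.ndarray, sublayer: Sublayer) -> np.ndarray:
    # GPT-2-style ordering: normalise only the sublayer's input; the residual path is left untouched.
    return x + sublayer(layer_norm(x))


rng = np.random.default_rng(0)
w = rng.standard_normal((8, 8))


def toy_sublayer(t: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for multi-head attention: any shape-preserving map works here.
    return t @ w


x = rng.standard_normal((2, 4, 8))
post = post_ln_block(x, toy_sublayer)
pre = pre_ln_block(x, toy_sublayer)
print(np.allclose(post.mean(axis=-1), 0.0, atol=1e-6))  # True: every post-LN token is normalised
```

In the post-LN output, every token has zero mean and (near) unit variance. In the pre-LN output, the raw residual x passes through unnormalised, so its per-token statistics are generally not normalised.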
Pipeline
Given x of shape (N, T, d_model) and weight matrices w_q, w_k, w_v, w_o, each of shape (d_model, d_model):
- Project: Q = x @ w_q, K = x @ w_k, V = x @ w_v, each of shape (N, T, d_model).
- Reshape: split the last dim d_model → (num_heads, d_head), then transpose to (N, num_heads, T, d_head).
- Per-head scaled dot-product: scores = Q @ Kᵀ / sqrt(d_head), shape (N, num_heads, T, T); attn = softmax(scores, dim=-1); per_head = attn @ V, shape (N, num_heads, T, d_head).
- Concat heads: transpose and reshape back to (N, T, d_model).
- Output projection + residual: out = (concat @ w_o) + x.
- Layer norm over the last dim with eps=1e-5 (no learned γ/β).
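The NumPy sketch below follows the six steps in order. It is not the reference solution, and the function name is hypothetical. NumPy stands in for whatever framework the tests use. The sketch applies the weights as x @ w, as the pipeline states. It also assumes that the layer norm uses the biased variance, which is what PyTorch's layer_norm computes.

```python
import numpy as np


def multi_head_attention_block(
    x: np.ndarray,
    w_q: np.ndarray,
    w_k: np.ndarray,
    w_v: np.ndarray,
    w_o: np.ndarray,
    num_heads: int,
    eps: float = 1e-5,
) -> np.ndarray:
    """Post-LN multi-head self-attention, LayerNorm(x + MultiHead(x)), for x of shape (N, T, d_model)."""
    n, seq_len, d_model = x.shape
    if d_model % num_heads != 0:
        raise ValueError("num_heads must divide d_model evenly")
    d_head = d_model // num_heads

    # Step 1, project: each of q, k, v is (N, T, d_model).
    q, k, v = x @ w_q, x @ w_k, x @ w_v

    # Step 2, reshape: (N, T, d_model) -> (N, num_heads, T, d_head).
    def split(t: np.ndarray) -> np.ndarray:
        return t.reshape(n, seq_len, num_heads, d_head).transpose(0, 2, 1, 3)

    q, k, v = split(q), split(k), split(v)

    # Step 3, per-head scaled dot-product with a numerically stable softmax over the last axis.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # (N, num_heads, T, T)
    scores -= scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    per_head = attn @ v  # (N, num_heads, T, d_head)

    # Step 4, concat heads: (N, num_heads, T, d_head) -> (N, T, d_model).
    concat = per_head.transpose(0, 2, 1, 3).reshape(n, seq_len, d_model)

    # Step 5, output projection + residual.
    out = concat @ w_o + x

    # Step 6, layer norm over the last axis, no learned gamma/beta.
    mean = out.mean(axis=-1, keepdims=True)
    var = out.var(axis=-1, keepdims=True)  # biased variance
    return (out - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))
w_q, w_k, w_v, w_o = (rng.standard_normal((8, 8)) for _ in range(4))
y = multi_head_attention_block(x, w_q, w_k, w_v, w_o, num_heads=2)
print(y.shape)  # (2, 5, 8)
```

Steps 2 and 4 are pure reshapes and transposes. With num_heads=1, the block therefore reduces to a single full-width attention head, which is a convenient sanity check.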
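The weight matrices are passed in directly (not as nn.Linear modules) so that tests are reproducible and framework-agnostic. They are applied as x @ w. By contrast, nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T.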
Inputs / Output
- x: shape (N, T, d_model), the input token sequence.
- w_q, w_k, w_v: shape (d_model, d_model).
- w_o: shape (d_model, d_model).
- num_heads: int, must divide d_model evenly.
- Output: shape (N, T, d_model).
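The constraints in this list can be checked up front. The sketch below assumes NumPy arrays. check_inputs is a hypothetical helper, and raising ValueError is our choice; the problem does not specify it.

```python
import numpy as np


def check_inputs(x, w_q, w_k, w_v, w_o, num_heads):
    """Hypothetical pre-flight check mirroring the Inputs / Output list; returns d_head."""
    if x.ndim != 3:
        raise ValueError(f"x must be (N, T, d_model), got shape {x.shape}")
    d_model = x.shape[-1]
    for name, w in (("w_q", w_q), ("w_k", w_k), ("w_v", w_v), ("w_o", w_o)):
        if w.shape != (d_model, d_model):
            raise ValueError(f"{name} must be ({d_model}, {d_model}), got {w.shape}")
    if num_heads <= 0 or d_model % num_heads != 0:
        raise ValueError(f"num_heads={num_heads} must divide d_model={d_model} evenly")
    return d_model // num_heads


x = np.zeros((2, 5, 12))
w = np.zeros((12, 12))
print(check_inputs(x, w, w, w, w, num_heads=4))  # 3
```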