hard end_to_end

Multi-Head Attention Block

Implement multi-head attention with a residual connection and layer norm — the core building block of every transformer model.

Why multiple heads?

A single attention head projects all tokens into one query/key/value space. Multiple heads let the model attend to information from different representation subspaces simultaneously: one head might focus on syntactic relationships, another on coreference, another on positional proximity. The outputs are concatenated and projected back to the model dimension.

Math (Vaswani et al., “Attention Is All You Need”, NeurIPS 2017, §3.2)

For a single head $i$ with subspace dimension $d_{\text{head}} = d_{\text{model}} / h$:

$$\text{head}_i = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_{\text{head}}}}\right) V_i$$

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\ldots,\text{head}_h)\,W^O$$
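To make the per-head formula concrete, here is a minimal pure-Python sketch of one head for a single sequence. The names `softmax` and `single_head_attention` are illustrative, not part of the problem's API, and the nested-list layout (T vectors of length d_head) is an assumption chosen to avoid any framework dependency.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def single_head_attention(q, k, v):
    """q, k, v: lists of T vectors, each of length d_head (one head, one sequence)."""
    d_head = len(q[0])
    scale = 1.0 / math.sqrt(d_head)
    out = []
    for q_t in q:
        # One score per key position: dot(q_t, k_s) / sqrt(d_head).
        scores = [scale * sum(a * b for a, b in zip(q_t, k_s)) for k_s in k]
        weights = softmax(scores)
        # Output for this query is the attention-weighted average of the values.
        out.append([sum(w * v_s[j] for w, v_s in zip(weights, v))
                    for j in range(len(v[0]))])
    return out
```

A zero query scores every key equally, so the output is the plain mean of the value vectors: `single_head_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 4.0], [4.0, 8.0]])` gives `[[3.0, 6.0]]`.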

This problem adds a residual connection and layer norm following the original (post-LN) transformer convention:

$$\text{output} = \text{LayerNorm}(x + \text{MultiHead}(x))$$
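Spelled out for a single token vector $z = x + \text{MultiHead}(x)$, with statistics taken over the last dimension and no learned scale or shift. The biased variance (dividing by $d_{\text{model}}$ rather than $d_{\text{model}} - 1$) is an assumption here, matching the usual framework default:

$$\mu = \frac{1}{d_{\text{model}}}\sum_{j=1}^{d_{\text{model}}} z_j, \qquad \sigma^2 = \frac{1}{d_{\text{model}}}\sum_{j=1}^{d_{\text{model}}} (z_j - \mu)^2, \qquad \text{LayerNorm}(z) = \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}}$$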

Post-LN vs pre-LN: The original 2017 paper normalises after adding the residual, LayerNorm(x + Sublayer(x)) (post-LN, used here). Many modern implementations instead normalise the sub-layer input and leave the residual path untouched, x + Sublayer(LayerNorm(x)) (pre-LN, e.g. GPT-2), which tends to make training more stable. Both conventions appear in production code, so it is worth knowing which is which.
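The difference between the two conventions is only the order of operations, which a small sketch makes visible. The function names and the `sublayer` / `norm` callables below are hypothetical stand-ins for attention and layer norm, not part of this problem's interface.

```python
def post_ln_block(x, sublayer, norm):
    # Post-LN (original transformer, used in this problem):
    # add the residual first, then normalise the sum.
    return norm(x + sublayer(x))

def pre_ln_block(x, sublayer, norm):
    # Pre-LN: normalise the sub-layer input, then add the
    # un-normalised residual.
    return x + sublayer(norm(x))
```

With toy scalar stand-ins `sublayer = lambda z: 2 * z` and `norm = lambda z: z / 10`, an input of `10.0` gives `3.0` for post-LN and `12.0` for pre-LN, which shows that the two orderings are not interchangeable.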

Pipeline

Given x of shape (N, T, d_model) and weight matrices w_q, w_k, w_v of shape (d_model, d_model) and w_o of shape (d_model, d_model):

  1. Project: Q = x @ w_q, K = x @ w_k, V = x @ w_v — each (N, T, d_model).
  2. Reshape: split the last dim d_model → (num_heads, d_head), giving (N, T, num_heads, d_head), then transpose to (N, num_heads, T, d_head).
  3. Per-head scaled dot-product: scores = Q @ Kᵀ / sqrt(d_head) — shape (N, num_heads, T, T); attn = softmax(scores, dim=-1); per_head = attn @ V — shape (N, num_heads, T, d_head).
  4. Concat heads: transpose + reshape → (N, T, d_model).
  5. Output projection + residual: out = (concat @ w_o) + x.
  6. Layer norm over last dim with eps=1e-5 (no learned γ/β).

The weight matrices are passed in directly (not as nn.Linear modules) so that tests are reproducible and framework-agnostic.
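The six steps above can be sketched in NumPy as follows. This is one possible reading of the pipeline, not the reference solution: the function name `multi_head_attention_block`, the use of NumPy, and the biased variance in the layer norm are all assumptions.

```python
import numpy as np

def multi_head_attention_block(x, w_q, w_k, w_v, w_o, num_heads, eps=1e-5):
    n, t, d_model = x.shape
    if d_model % num_heads != 0:
        raise ValueError("num_heads must divide d_model evenly")
    d_head = d_model // num_heads

    # 1. Project: each result has shape (N, T, d_model).
    q, k, v = x @ w_q, x @ w_k, x @ w_v

    # 2. Split heads: (N, T, d_model) -> (N, num_heads, T, d_head).
    def split_heads(a):
        return a.reshape(n, t, num_heads, d_head).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # 3. Scaled dot-product attention per head.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # (N, h, T, T)
    scores = scores - scores.max(axis=-1, keepdims=True)    # softmax stability
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    per_head = attn @ v                                      # (N, h, T, d_head)

    # 4. Concat heads: (N, h, T, d_head) -> (N, T, d_model).
    concat = per_head.transpose(0, 2, 1, 3).reshape(n, t, d_model)

    # 5. Output projection plus residual.
    out = concat @ w_o + x

    # 6. Layer norm over the last dim, no learned scale or shift.
    #    Assumption: biased (population) variance.
    mean = out.mean(axis=-1, keepdims=True)
    var = out.var(axis=-1, keepdims=True)
    return (out - mean) / np.sqrt(var + eps)
```

Called with `x` of shape `(2, 3, 8)`, four `(8, 8)` weight matrices, and `num_heads=2`, the function returns an array of shape `(2, 3, 8)` whose rows along the last axis have mean close to 0 and variance close to 1.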

Inputs / Output

  • x: shape (N, T, d_model) — input token sequence.
  • w_q, w_k, w_v: shape (d_model, d_model).
  • w_o: shape (d_model, d_model).
  • num_heads: int — must divide d_model evenly.
  • Output: shape (N, T, d_model).
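A quick way to check the shape contract before writing any tensor code is to trace the shapes by hand. The helper below is a hypothetical sketch (the name `attention_shapes` and the dictionary keys are illustrative) that also enforces the divisibility requirement on num_heads.

```python
def attention_shapes(n, t, d_model, num_heads):
    """Return the tensor shape expected at each stage of the pipeline."""
    if d_model % num_heads != 0:
        raise ValueError("num_heads must divide d_model evenly")
    d_head = d_model // num_heads
    return {
        "projected": (n, t, d_model),            # Q, K, V after projection
        "split_heads": (n, num_heads, t, d_head),
        "scores": (n, num_heads, t, t),
        "per_head": (n, num_heads, t, d_head),
        "output": (n, t, d_model),               # same shape as the input x
    }
```

For example, `attention_shapes(2, 5, 64, 8)["split_heads"]` is `(2, 8, 5, 8)`, while `num_heads=6` raises `ValueError` because 6 does not divide 64.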

