
Causal Self-Attention Block

Implement causal (masked) multi-head self-attention with a residual connection and layer norm: the core building block of GPT-style decoder transformers.

Causal vs Bi-directional Attention

In an encoder (e.g. BERT) every token can attend to every other token. In a decoder (e.g. GPT) each token may only attend to itself and earlier tokens; otherwise the model could “cheat” by peeking at future tokens during training. This constraint is enforced with a causal mask: before applying softmax, set every score at position (i, j) with j > i to -1e9 (a large negative number that becomes ~0 after softmax).
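A minimal PyTorch sketch of the causal mask for a single head with T = 4 (the scores are random and seeded only for repeatability): torch.tril builds the lower-triangular mask and masked_fill writes -1e9 wherever j > i.

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)                      # raw attention scores for one head
mask = torch.tril(torch.ones(T, T))             # 1 on/below the diagonal (j <= i), 0 above
masked = scores.masked_fill(mask == 0, -1e9)    # block future positions j > i
attn = torch.softmax(masked, dim=-1)            # each row is a distribution over j <= i

print(attn)  # entries above the diagonal are 0; every row sums to 1
```

Row 0 can only see position 0, so its weight there is exactly 1; row T-1 sees the whole sequence.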

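Why -1e9 and not -inf? If every entry in a row is masked with -inf, softmax computes 0/0 and returns NaN, which then propagates into the loss and the gradients. A purely causal mask always keeps the diagonal, so this mainly bites when the causal mask is combined with other masks (such as padding). A large finite value like -1e9 sidesteps the problem and is a common choice in practice.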
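The fully-masked-row failure mode can be reproduced directly. This is an illustrative sketch using a hand-built row of three scores, not part of the required pipeline:

```python
import torch

# Every position masked with -inf: softmax subtracts the max (-inf),
# giving -inf - (-inf) = nan, so the whole row becomes NaN.
inf_out = torch.softmax(torch.full((3,), float("-inf")), dim=-1)
print(inf_out)     # tensor([nan, nan, nan])

# The same row masked with -1e9 stays finite: all entries are equal,
# so softmax just returns a uniform distribution.
finite_out = torch.softmax(torch.full((3,), -1e9), dim=-1)
print(finite_out)  # tensor([0.3333, 0.3333, 0.3333])
```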

Pipeline

Given x of shape (N, T, d_model) and weight matrices w_q, w_k, w_v, w_o each of shape (d_model, d_model):

  1. Project: Q = x @ w_q, K = x @ w_k, V = x @ w_v, each of shape (N, T, d_model).
  2. Reshape: split the last dim d_model → (num_heads, d_head), where d_head = d_model / num_heads, then transpose to (N, num_heads, T, d_head).
  3. Scaled dot-product scores: scores = Q @ Kᵀ / sqrt(d_head), shape (N, num_heads, T, T).
  4. Causal mask: build a (T, T) lower-triangular matrix of ones (torch.tril), then masked_fill the upper triangle (where mask == 0) with -1e9.
  5. Softmax + weighted sum: attn = softmax(scores, dim=-1); per_head = attn @ V, shape (N, num_heads, T, d_head).
  6. Concat heads: transpose back to (N, T, num_heads, d_head), then reshape → (N, T, d_model).
  7. Output projection + residual: out = (concat @ w_o) + x.
  8. Layer norm over the last dim with eps=1e-5 (no learned γ/β).
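The eight steps above can be sketched in PyTorch as follows. This is a minimal sketch under stated assumptions, not the official solution: the function name causal_self_attention is hypothetical, inputs are assumed to be float tensors, and step 8 uses F.layer_norm with no weight or bias, which corresponds to "no learned γ/β".

```python
import math

import torch
import torch.nn.functional as F


def causal_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Hypothetical sketch: causal multi-head self-attention + residual + layer norm."""
    N, T, d_model = x.shape
    assert d_model % num_heads == 0, "num_heads must divide d_model"
    d_head = d_model // num_heads

    # 1. Project to queries, keys, values: each (N, T, d_model)
    q, k, v = x @ w_q, x @ w_k, x @ w_v

    # 2. Split heads: (N, T, d_model) -> (N, num_heads, T, d_head)
    def split(t):
        return t.view(N, T, num_heads, d_head).transpose(1, 2)

    q, k, v = split(q), split(k), split(v)

    # 3. Scaled dot-product scores: (N, num_heads, T, T)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)

    # 4. Causal mask: keep j <= i, fill j > i with -1e9 (mask broadcasts over N and heads)
    mask = torch.tril(torch.ones(T, T, device=x.device))
    scores = scores.masked_fill(mask == 0, -1e9)

    # 5. Softmax over keys, then weight the values: (N, num_heads, T, d_head)
    attn = torch.softmax(scores, dim=-1)
    per_head = attn @ v

    # 6. Concatenate heads: (N, num_heads, T, d_head) -> (N, T, d_model)
    concat = per_head.transpose(1, 2).reshape(N, T, d_model)

    # 7. Output projection plus residual connection
    out = concat @ w_o + x

    # 8. Layer norm over the last dim, eps=1e-5, no learned gamma/beta
    return F.layer_norm(out, (d_model,), eps=1e-5)


# Example usage with small random weights
torch.manual_seed(0)
x = torch.randn(2, 5, 8)
w_q, w_k, w_v, w_o = (torch.randn(8, 8) * 0.1 for _ in range(4))
y = causal_self_attention(x, w_q, w_k, w_v, w_o, num_heads=2)
print(y.shape)  # torch.Size([2, 5, 8])
```

Because the mask zeroes every j > i weight, changing the last token should leave all earlier output positions unchanged, which makes a quick sanity check for the masking step.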

GPT Context

This block is the heart of every GPT-style model. Add a feed-forward sublayer after it to form a full transformer decoder layer; stack such layers on top of token embeddings and position encodings to get a GPT-style decoder.

The weight matrices are passed in directly (not as nn.Linear modules) so that tests are reproducible and framework-agnostic.
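If you later want to move these raw matrices into nn.Linear modules, note the layout difference: nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T, so a raw matrix used as x @ w must be loaded transposed. A small sketch, assuming bias-free layers and a d_model of 8 chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
w_q = torch.randn(d_model, d_model)   # raw (d_model, d_model) matrix, as passed in this problem
x = torch.randn(2, 5, d_model)

linear = nn.Linear(d_model, d_model, bias=False)
with torch.no_grad():
    linear.weight.copy_(w_q.T)        # nn.Linear computes x @ weight.T, so store w_q transposed

same = torch.allclose(x @ w_q, linear(x), atol=1e-6)
print(same)  # True
```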

Inputs / Output

  • x: shape (N, T, d_model), the input token sequence.
  • w_q, w_k, w_v: shape (d_model, d_model).
  • w_o: shape (d_model, d_model).
  • num_heads: int; must divide d_model evenly.
  • Output: shape (N, T, d_model).
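The shape contract above can be enforced up front. The helper below, check_inputs, is hypothetical (not part of the problem's interface); it is a sketch that raises ValueError on a mismatch and returns d_head when the inputs are valid:

```python
import torch


def check_inputs(x, w_q, w_k, w_v, w_o, num_heads):
    """Hypothetical helper: validate shapes against the Inputs / Output spec."""
    if x.dim() != 3:
        raise ValueError(f"x must be (N, T, d_model), got {tuple(x.shape)}")
    d_model = x.shape[-1]
    # Every weight matrix must be square with side d_model
    for name, w in (("w_q", w_q), ("w_k", w_k), ("w_v", w_v), ("w_o", w_o)):
        if tuple(w.shape) != (d_model, d_model):
            raise ValueError(f"{name} must be ({d_model}, {d_model}), got {tuple(w.shape)}")
    # num_heads must split d_model into equal-sized heads
    if num_heads <= 0 or d_model % num_heads != 0:
        raise ValueError(f"num_heads={num_heads} must divide d_model={d_model}")
    return d_model // num_heads  # d_head


ws = [torch.zeros(8, 8)] * 4
print(check_inputs(torch.zeros(2, 5, 8), *ws, num_heads=2))  # 4
```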

