hard end_to_end

Transformer Block with RoPE

Implement a transformer block that applies Rotary Position Embeddings (RoPE) to Q and K before computing attention — the architecture used in LLaMA, GPT-NeoX, PaLM, and most modern open-weight LLMs.

Background: RoFormer (Su et al., 2021)

Classical absolute position encodings (sinusoidal, learned) add position info once at the embedding layer. RoPE takes a different approach: it injects relative-position information directly into the attention score by rotating Q and K in head-space.

The key insight is that if Q and K are rotated by angles m·θ and n·θ respectively (for positions m and n), then the score Q_rot · K_rot^T depends on position only through the difference (m - n)·θ; the absolute positions m and n cancel out. This makes RoPE a relative position encoding disguised as an absolute one.
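
The cancellation can be checked numerically in the simplest case. The following is a minimal sketch, assuming a single 2-D feature pair and one frequency θ; the helper names rotate and score are illustrative and not part of the problem's API.

```python
import math

def rotate(vec, angle):
    """Rotate a 2-D vector counter-clockwise by `angle` radians."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def score(q, k, m, n, theta):
    """Dot product of q rotated by m*theta and k rotated by n*theta."""
    qr = rotate(q, m * theta)
    kr = rotate(k, n * theta)
    return qr[0] * kr[0] + qr[1] * kr[1]

q, k, theta = (0.3, -1.2), (0.7, 0.5), 0.1

# Same offset (m - n = 3) at two very different absolute positions.
near = score(q, k, 5, 2, theta)
far = score(q, k, 105, 102, theta)
print(math.isclose(near, far, abs_tol=1e-9))  # True
```

Changing the offset (for example m = 5, n = 4) changes the score, so the rotation carries relative position and nothing else.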

Why Q and K, but not V?

Position information enters the computation via the attention score Q · K^T. By rotating Q and K, relative distances are baked into which tokens attend to which. The value vectors V carry content, not position — rotating them would distort the information being aggregated, not the selection of what to aggregate.

RoPE Rotation Rule

For a query vector at position t, head h, and frequency-pair index k:

q'[..., t, h, 2k]   = q[..., t, h, 2k]   * cos[t, k]  -  q[..., t, h, 2k+1] * sin[t, k]
q'[..., t, h, 2k+1] = q[..., t, h, 2k]   * sin[t, k]  +  q[..., t, h, 2k+1] * cos[t, k]

The same rule applies to K. The freqs_cos and freqs_sin tables (shape (T, d_head/2)) are precomputed and passed as arguments — you do not need to derive frequencies from scratch.
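
One way to express the rotation rule with even/odd slicing is sketched below in NumPy. It assumes floating-point input whose last two axes are (T, d_head), as after the multi-head transpose, so the (T, d_head/2) tables broadcast over the leading axes. The function name apply_rope is illustrative.

```python
import numpy as np

def apply_rope(x, freqs_cos, freqs_sin):
    """Rotate consecutive (even, odd) feature pairs of x.

    x:         (..., T, d_head), float; leading axes (e.g. N, num_heads) are arbitrary
    freqs_cos: (T, d_head // 2)
    freqs_sin: (T, d_head // 2)
    """
    x_even = x[..., 0::2]  # components at indices 2k
    x_odd = x[..., 1::2]   # components at indices 2k+1
    out = np.empty_like(x)
    out[..., 0::2] = x_even * freqs_cos - x_odd * freqs_sin
    out[..., 1::2] = x_even * freqs_sin + x_odd * freqs_cos
    return out
```

Because each pair undergoes a plain 2-D rotation, the norm of every vector is unchanged. With cos = 1 and sin = 0 the function is the identity.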

Full Pipeline

  1. Project: Q = x @ w_q, K = x @ w_k, V = x @ w_v — each (N, T, d_model).
  2. Multi-head reshape: split d_model → (num_heads, d_head), transpose to (N, num_heads, T, d_head).
  3. Apply RoPE to Q and K using even/odd index slicing (see Hints). V is left unchanged.
  4. Scaled dot-product: scores = Q_rot @ K_rot^T / sqrt(d_head). No causal mask (this is a bidirectional encoder-style block).
  5. Softmax + weighted sum: attn = softmax(scores, dim=-1); out_heads = attn @ V.
  6. Concat + output projection + residual: out = (concat @ w_o) + x.
  7. Layer norm over last dim with eps=1e-5 (no learned γ/β).
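
The seven steps above can be sketched end to end in NumPy. This is an illustrative reference under stated assumptions, not the site's official solution. The names rope and rope_block are hypothetical, the softmax subtracts the row maximum for numerical stability, and the layer norm uses the biased variance (dividing by d_model).

```python
import numpy as np

def rope(x, cos, sin):
    # x: (N, H, T, d_head); cos/sin: (T, d_head // 2), broadcast over N and H
    out = np.empty_like(x)
    out[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
    out[..., 1::2] = x[..., 0::2] * sin + x[..., 1::2] * cos
    return out

def rope_block(x, w_q, w_k, w_v, w_o, num_heads, freqs_cos, freqs_sin, eps=1e-5):
    N, T, d_model = x.shape
    d_head = d_model // num_heads

    def split_heads(a):
        # (N, T, d_model) -> (N, num_heads, T, d_head)
        return a.reshape(N, T, num_heads, d_head).transpose(0, 2, 1, 3)

    # Steps 1-2: project, then split into heads
    q = split_heads(x @ w_q)
    k = split_heads(x @ w_k)
    v = split_heads(x @ w_v)

    # Step 3: rotate Q and K only; V is left unchanged
    q = rope(q, freqs_cos, freqs_sin)
    k = rope(k, freqs_cos, freqs_sin)

    # Step 4: scaled dot-product scores, no causal mask
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)

    # Step 5: softmax over the key axis, then weighted sum of values
    scores = scores - scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    heads = attn @ v  # (N, num_heads, T, d_head)

    # Step 6: concatenate heads, output projection, residual
    concat = heads.transpose(0, 2, 1, 3).reshape(N, T, d_model)
    out = concat @ w_o + x

    # Step 7: layer norm over the last dim, no learned scale or shift
    mean = out.mean(axis=-1, keepdims=True)
    var = out.var(axis=-1, keepdims=True)
    return (out - mean) / np.sqrt(var + eps)
```

A useful sanity check for this sketch: building the tables from positions t + c instead of t (any constant shift c) should leave the output unchanged, since the scores depend only on position differences.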

Inputs / Output

  • x: (N, T, d_model).
  • w_q, w_k, w_v, w_o: (d_model, d_model).
  • num_heads: d_model must be divisible by num_heads, and d_head = d_model / num_heads must be even (RoPE pairs consecutive dims).
  • freqs_cos: (T, d_head/2) — precomputed cosines.
  • freqs_sin: (T, d_head/2) — precomputed sines.
  • Output: (N, T, d_model).
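
For local testing it helps to have tables of the right shape. The problem supplies freqs_cos and freqs_sin and does not say how they were generated. The sketch below assumes the RoFormer schedule θ_k = base^(-2k/d_head) with base 10000. The helper build_rope_tables is hypothetical.

```python
import numpy as np

def build_rope_tables(T, d_head, base=10000.0):
    """Return (freqs_cos, freqs_sin), each of shape (T, d_head // 2).

    Assumes theta_k = base ** (-2k / d_head); the problem's actual
    tables may use a different schedule.
    """
    assert d_head % 2 == 0, "RoPE pairs consecutive dims, so d_head must be even"
    k = np.arange(d_head // 2)
    theta = base ** (-2.0 * k / d_head)      # one frequency per pair
    angles = np.outer(np.arange(T), theta)   # angles[t, k] = t * theta_k
    return np.cos(angles), np.sin(angles)
```

Position 0 always maps to the identity rotation (cos = 1, sin = 0), and higher pair indices k rotate more slowly.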

Hints

Tags: attention, transformer, rope, position-encoding
