
Multi-Query Attention

Implement Multi-Query Attention (MQA) from "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019).

In MQA, each head has its own query projection, but all heads share a single set of keys and values. This shrinks the KV cache by a factor of n_heads while largely preserving model quality.
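To see the cache saving concretely, compare per-token KV-cache sizes for standard multi-head attention versus MQA. This is a back-of-envelope sketch; the `n_heads` and `d_k` values are illustrative, not from the problem:

```python
n_heads, d_k = 16, 64           # illustrative sizes, not specified by the problem

mha_kv = 2 * n_heads * d_k      # MHA caches a per-head K and V for every token
mqa_kv = 2 * d_k                # MQA caches one shared K and V for every token

print(mha_kv // mqa_kv)         # reduction factor equals n_heads -> prints 16
```

The factor is exactly `n_heads`, which is why MQA matters most for memory-bound autoregressive decoding, where the KV cache dominates.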

Given:

  • Q: shape (seq_len, n_heads, d_k) — per-head queries
  • K: shape (seq_len, d_k) — keys shared across all heads
  • V: shape (seq_len, d_k) — values shared across all heads

For each head h (softmax taken over the key dimension): $$\text{scores}_h = \frac{Q[:, h, :] \cdot K^T}{\sqrt{d_k}}$$ $$\text{attn}_h = \text{softmax}(\text{scores}_h)$$ $$\text{out}_h = \text{attn}_h \cdot V$$

Output: Tensor of shape (seq_len, n_heads, d_k) — the per-head outputs stacked along the head dimension.
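The per-head formulas above can be sketched in NumPy in one vectorized pass; the `einsum` subscripts simply batch the per-head matrix products over h. The function name and use of `einsum` are implementation choices, not part of the problem statement:

```python
import numpy as np

def multi_query_attention(Q, K, V):
    """MQA sketch: per-head queries, a single shared K and V.

    Q: (seq_len, n_heads, d_k), K: (seq_len, d_k), V: (seq_len, d_k)
    Returns: (seq_len, n_heads, d_k)
    """
    d_k = Q.shape[-1]
    # scores_h = Q[:, h, :] @ K.T / sqrt(d_k)  -> shape (n_heads, seq_len, seq_len)
    scores = np.einsum("qhd,kd->hqk", Q, K) / np.sqrt(d_k)
    # softmax over the key axis, shifted by the row max for numerical stability
    scores -= scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    # out_h = attn_h @ V, then move the head axis back to position 1
    return np.einsum("hqk,kd->qhd", attn, V)
```

Note that the same K is reused in every head's score computation; only Q varies with h, which is exactly what makes the shared KV cache possible.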
