Implement Multi-Query Attention (MQA) from “Fast Transformer Decoding” (Shazeer, 2019).
In MQA, queries retain multiple heads, but a single key and a single value projection are shared across all heads. This shrinks the KV cache by a factor of n_heads, speeding up incremental decoding, with only a minor loss in model quality.
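The cache saving follows directly from the shapes: multi-head attention stores per-head K and V, while MQA stores one shared pair. A quick back-of-envelope sketch (the sizes below are illustrative, not from the paper):

```python
# Illustrative model sizes (assumptions, not from Shazeer 2019)
n_heads, d_k, n_layers, seq_len = 32, 128, 24, 2048

# KV cache = 2 tensors (K and V) per layer, held for every cached position
mha_cache = 2 * n_layers * seq_len * n_heads * d_k  # multi-head: per-head K/V
mqa_cache = 2 * n_layers * seq_len * d_k            # multi-query: shared K/V

print(mha_cache // mqa_cache)  # reduction factor equals n_heads -> 32
```

The query projections are unchanged, so parameter count and compute per step drop only slightly; the win is memory bandwidth on the cached K/V.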
Given:
- Q: shape (seq_len, n_heads, d_k) — per-head queries
- K: shape (seq_len, d_k) — single shared key
- V: shape (seq_len, d_k) — single shared value

For each head h:

$$\text{scores}_h = \frac{Q[:, h, :] \cdot K^T}{\sqrt{d_k}}$$

$$\text{attn}_h = \text{softmax}(\text{scores}_h)$$

$$\text{out}_h = \text{attn}_h \cdot V$$
Output: Tensor of shape (seq_len, n_heads, d_k) — concatenated head outputs.
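One possible reference implementation of the spec above, using NumPy and `einsum` to batch the per-head computation (function and helper names are my own; no causal mask is applied, matching the equations as written):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_attention(Q, K, V):
    """Multi-query attention: per-head queries, one shared K and V.

    Q: (seq_len, n_heads, d_k)
    K: (seq_len, d_k)
    V: (seq_len, d_k)
    Returns: (seq_len, n_heads, d_k)
    """
    d_k = Q.shape[-1]
    # scores[i, h, j] = Q[i, h, :] . K[j, :] / sqrt(d_k)
    scores = np.einsum("ihd,jd->ihj", Q, K) / np.sqrt(d_k)
    # Normalize over key positions j for each (query position, head)
    attn = softmax(scores, axis=-1)
    # out[i, h, :] = sum_j attn[i, h, j] * V[j, :]
    return np.einsum("ihj,jd->ihd", attn, V)
```

Because K and V carry no head axis, the einsum broadcasts the single key/value across all heads; a per-head loop over `Q[:, h, :] @ K.T` would produce the same result, just more slowly.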