Implement Grouped-Query Attention (GQA) from “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (Ainslie et al., 2023).
GQA interpolates between multi-head attention (MHA) and multi-query attention (MQA): query heads are divided into groups, and each group shares a single set of key-value heads. With one group per query head, GQA reduces to MHA; with a single group shared by all query heads, it reduces to MQA.
Given:
- Q: shape (seq_len, n_heads, d_k) — queries with n_heads heads
- K: shape (seq_len, n_kv_heads, d_k) — keys with n_kv_heads groups
- V: shape (seq_len, n_kv_heads, d_k) — values with n_kv_heads groups

n_heads must be divisible by n_kv_heads.
Head h uses KV group g = h // (n_heads // n_kv_heads). For example, with n_heads = 8 and n_kv_heads = 2, heads 0–3 share group 0 and heads 4–7 share group 1. For each head h:
$$\text{scores}_h = \frac{Q[:, h, :] \cdot K[:, g, :]^T}{\sqrt{d_k}}$$
$$\text{out}_h = \text{softmax}(\text{scores}_h) \cdot V[:, g, :]$$
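A minimal NumPy sketch of these two equations (the function name `gqa_attention`, the explicit per-head loop, and the max-subtraction in the softmax are illustrative choices, not prescribed by the paper):

```python
import numpy as np

def gqa_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Grouped-query attention over un-batched inputs.

    Q: (seq_len, n_heads, d_k); K, V: (seq_len, n_kv_heads, d_k).
    Returns a tensor of shape (seq_len, n_heads, d_k).
    """
    seq_len, n_heads, d_k = Q.shape
    n_kv_heads = K.shape[1]
    assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    group_size = n_heads // n_kv_heads

    out = np.empty_like(Q)
    for h in range(n_heads):
        g = h // group_size  # KV group shared by head h
        # Scaled dot-product scores against the group's keys: (seq_len, seq_len)
        scores = Q[:, h, :] @ K[:, g, :].T / np.sqrt(d_k)
        # Numerically stable softmax over the key dimension
        scores -= scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        # Weighted sum of the group's values: (seq_len, d_k)
        out[:, h, :] = weights @ V[:, g, :]
    return out
```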
Output: Tensor of shape (seq_len, n_heads, d_k).
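As a quick shape check (random inputs, arbitrary sizes), assuming the `gqa_attention` sketch above:

```python
rng = np.random.default_rng(0)
seq_len, n_heads, n_kv_heads, d_k = 5, 8, 2, 16
Q = rng.standard_normal((seq_len, n_heads, d_k))
K = rng.standard_normal((seq_len, n_kv_heads, d_k))
V = rng.standard_normal((seq_len, n_kv_heads, d_k))
out = gqa_attention(Q, K, V)
print(out.shape)  # (5, 8, 16)
```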