
Grouped-Query Attention

Implement Grouped-Query Attention (GQA) from “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (Ainslie et al., 2023).

GQA is a generalization that interpolates between multi-head attention (MHA) and multi-query attention (MQA): query heads are divided into groups, and each group shares a single set of key-value heads. With n_kv_heads = n_heads it reduces to MHA; with n_kv_heads = 1 it reduces to MQA.

Given:

  • Q: shape (seq_len, n_heads, d_k) — query with n_heads heads
  • K: shape (seq_len, n_kv_heads, d_k) — key with n_kv_heads groups
  • V: shape (seq_len, n_kv_heads, d_k) — value with n_kv_heads groups
  • n_heads must be divisible by n_kv_heads

Head h uses KV group g = h // (n_heads // n_kv_heads):

$$\text{scores}_h = \frac{Q[:, h, :] \cdot K[:, g, :]^T}{\sqrt{d_k}}$$

$$\text{out}_h = \text{softmax}(\text{scores}_h) \cdot V[:, g, :]$$

Output: Tensor of shape (seq_len, n_heads, d_k).
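The head-to-group mapping and the per-head attention above can be sketched in NumPy as follows. This is a minimal reference implementation under the stated shape conventions; the function and variable names (`grouped_query_attention`, `group_size`) are illustrative, not from the paper.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    # Q: (seq_len, n_heads, d_k); K, V: (seq_len, n_kv_heads, d_k)
    seq_len, n_heads, d_k = Q.shape
    n_kv_heads = K.shape[1]
    assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    group_size = n_heads // n_kv_heads

    out = np.empty_like(Q)
    for h in range(n_heads):
        g = h // group_size  # KV group shared by this query head
        scores = Q[:, h, :] @ K[:, g, :].T / np.sqrt(d_k)  # (seq_len, seq_len)
        out[:, h, :] = softmax(scores) @ V[:, g, :]
    return out
```

Setting `n_kv_heads = n_heads` gives each head its own KV pair (MHA); `n_kv_heads = 1` makes every head share one KV pair (MQA).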
