Grouped-Query Attention
Implement Grouped-Query Attention (GQA) from “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” (Ainslie et al., 2023).
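GQA interpolates between multi-head attention (MHA) and multi-query attention (MQA). The query heads are divided into n_kv_heads groups, and all query heads in a group share a single key head and a single value head. Setting n_kv_heads = n_heads recovers MHA, and setting n_kv_heads = 1 recovers MQA.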
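As an illustration of the grouping, the sketch below lists which query heads share each KV head for 8 query heads. The helper name `kv_groups` is hypothetical. It assumes contiguous groups (heads 0 to group_size - 1 share KV head 0, and so on), which matches the index rule g = h // (n_heads // n_kv_heads) given below.

```python
def kv_groups(n_heads: int, n_kv_heads: int) -> list[list[int]]:
    """Return, for each KV head g, the query heads that share it (contiguous groups)."""
    if n_kv_heads < 1 or n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a positive multiple of n_kv_heads")
    group_size = n_heads // n_kv_heads
    # Query head h belongs to group g exactly when h // group_size == g.
    return [list(range(g * group_size, (g + 1) * group_size)) for g in range(n_kv_heads)]


print(kv_groups(8, 8))  # MHA: [[0], [1], [2], [3], [4], [5], [6], [7]]
print(kv_groups(8, 2))  # GQA: [[0, 1, 2, 3], [4, 5, 6, 7]]
print(kv_groups(8, 1))  # MQA: [[0, 1, 2, 3, 4, 5, 6, 7]]
```

The two extremes show how the single n_kv_heads knob moves between MHA (one KV head per query head) and MQA (one KV head for all query heads).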
Given:
- Q: shape (seq_len, n_heads, d_k), the query tensor with n_heads heads
- K: shape (seq_len, n_kv_heads, d_k), the key tensor with n_kv_heads heads (one per group)
- V: shape (seq_len, n_kv_heads, d_k), the value tensor with n_kv_heads heads (one per group)
- n_heads must be divisible by n_kv_heads
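The constraints in this list can be checked up front before any attention is computed. Below is a minimal sketch that works on shape tuples only. The function name `check_gqa_shapes` is hypothetical, and it assumes V shares K's shape exactly, as stated above.

```python
def check_gqa_shapes(q_shape: tuple, k_shape: tuple, v_shape: tuple) -> int:
    """Validate GQA input shapes and return group_size = n_heads // n_kv_heads.

    Expects q_shape = (seq_len, n_heads, d_k) and
    k_shape = v_shape = (seq_len, n_kv_heads, d_k).
    """
    seq_len, n_heads, d_k = q_shape
    n_kv_heads = k_shape[1]
    if n_kv_heads < 1:
        raise ValueError("n_kv_heads must be at least 1")
    if tuple(k_shape) != (seq_len, n_kv_heads, d_k):
        raise ValueError(f"K has shape {k_shape}, expected (seq_len, n_kv_heads, d_k)")
    if tuple(v_shape) != tuple(k_shape):
        raise ValueError(f"V has shape {v_shape}, expected the same shape as K")
    if n_heads % n_kv_heads != 0:
        raise ValueError(f"n_heads={n_heads} is not divisible by n_kv_heads={n_kv_heads}")
    return n_heads // n_kv_heads


print(check_gqa_shapes((5, 8, 16), (5, 2, 16), (5, 2, 16)))  # 4
```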
Query head h uses KV group g = h // (n_heads // n_kv_heads), where // denotes integer division:
$$\text{scores}_h = \frac{Q[:, h, :] \cdot K[:, g, :]^T}{\sqrt{d_k}}$$
$$\text{out}_h = \text{softmax}(\text{scores}_h) \cdot V[:, g, :]$$
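Here the softmax is applied row-wise, over the key positions of each (seq_len, seq_len) score matrix. Output: a tensor of shape (seq_len, n_heads, d_k) whose slice [:, h, :] is out_h.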
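The formulas above translate almost line for line into NumPy. The sketch below assumes NumPy is available, and the function names are hypothetical. `grouped_query_attention` loops over query heads and follows the per-head equations directly. `grouped_query_attention_vectorized` is one possible loop-free variant: it reshapes Q so that each group of query heads lines up with its shared KV head, which avoids copying K and V n_heads // n_kv_heads times.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax; subtracting the max does not change the result."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Per-head GQA. Q: (seq_len, n_heads, d_k); K, V: (seq_len, n_kv_heads, d_k)."""
    seq_len, n_heads, d_k = Q.shape
    n_kv_heads = K.shape[1]
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    group_size = n_heads // n_kv_heads

    out = np.empty((seq_len, n_heads, V.shape[-1]))
    for h in range(n_heads):
        g = h // group_size                                   # KV group for query head h
        scores = Q[:, h, :] @ K[:, g, :].T / np.sqrt(d_k)     # (seq_len, seq_len)
        out[:, h, :] = softmax(scores, axis=-1) @ V[:, g, :]  # (seq_len, d_k)
    return out


def grouped_query_attention_vectorized(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Loop-free variant of the same computation.

    Reshaping Q (C order) to (seq_len, n_kv_heads, group_size, d_k) puts head
    h = g * group_size + j at index [g, j], so g == h // group_size as above.
    """
    seq_len, n_heads, d_k = Q.shape
    n_kv_heads = K.shape[1]
    group_size = n_heads // n_kv_heads
    Qg = Q.reshape(seq_len, n_kv_heads, group_size, d_k)
    scores = np.einsum("qgjd,kgd->gjqk", Qg, K) / np.sqrt(d_k)  # (n_kv, group, seq, seq)
    weights = softmax(scores, axis=-1)                         # normalize over keys
    out = np.einsum("gjqk,kgd->qgjd", weights, V)              # (seq, n_kv, group, d_k)
    return out.reshape(seq_len, n_heads, d_k)


rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 8, 16))  # seq_len=5, n_heads=8, d_k=16
K = rng.standard_normal((5, 2, 16))  # n_kv_heads=2, so 4 query heads per group
V = rng.standard_normal((5, 2, 16))

out = grouped_query_attention(Q, K, V)
print(out.shape)  # (5, 8, 16)
print(np.allclose(out, grouped_query_attention_vectorized(Q, K, V)))  # True
```

The loop version is easiest to check against the equations. The reshape trick only works because the groups are contiguous: a different head-to-group assignment would need a different reshape or an explicit gather.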