Grouped-Query Attention (GQA)
Why this matters
Standard multi-head attention (MHA) gives every head its own Q, K, and V
projection. With H=64 heads in a 70B-parameter model, the KV cache grows
linearly with context length, and at long contexts it dominates GPU memory.
Multi-Query Attention (MQA, Shazeer 2019) shares a single K/V head across all H query heads. The cache shrinks by a factor of H, but quality drops.
Grouped-Query Attention (GQA, Ainslie et al. 2023) is the
compromise: it keeps K K/V heads, with 1 < K < H, and each one is shared
by H/K query heads. LLaMA-2 70B uses H=64, K=8, which gives an 8× cache
reduction with negligible quality loss. Mistral 7B and Falcon 40B also use GQA.
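To make the cache savings concrete, the arithmetic can be sketched as below. The layer count (80), head dimension (128), 4096-token context, and 2-byte elements are assumptions chosen for illustration, not figures from this page, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(seq_len, num_kv_heads, head_dim, num_layers, bytes_per_elem=2):
    """Bytes needed to cache K and V for one sequence across all layers."""
    # Factor 2: one cached tensor for K and one for V.
    return 2 * seq_len * num_kv_heads * head_dim * num_layers * bytes_per_elem


# Assumed sizes for illustration only.
T, d, L = 4096, 128, 80
H, K = 64, 8

mha = kv_cache_bytes(T, H, d, L)   # every query head has its own K/V head
gqa = kv_cache_bytes(T, K, d, L)   # K shared K/V heads
mqa = kv_cache_bytes(T, 1, d, L)   # a single shared K/V head

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {size / 2**30:.3f} GiB per sequence")
print("MHA / GQA =", mha // gqa)
```

Under these assumptions the ratio between MHA and GQA is exactly H/K, independent of context length or layer count.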
The shapes
Standard MHA (num_heads=H):
- Q: (T, H, d)
- K: (T, H, d)
- V: (T, H, d)
- KV cache: 2 * T * H * d floats per layer.
GQA (num_heads=H, num_kv_heads=K, K < H):
- Q: (T, H, d), full
- K: (T, K, d), only K heads
- V: (T, K, d), only K heads
- KV cache: 2 * T * K * d floats per layer, H/K× smaller.
Each K/V head is shared by H/K query heads. Concretely, you compute K
and V with K heads, then repeat them H/K times along the head
axis to broadcast against Q.
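The shapes above can be checked with a few lines of JAX. This is a small sketch with made-up sizes (T=5, H=4, K=2, d=8) and random tensors standing in for the projected Q and K; it is not part of the exercise's reference code.

```python
import jax
import jax.numpy as jnp

T, H, K, d = 5, 4, 2, 8  # small illustrative sizes (assumed)
key_q, key_k = jax.random.split(jax.random.PRNGKey(0))

q = jax.random.normal(key_q, (T, H, d))  # one query vector per head
k = jax.random.normal(key_k, (T, K, d))  # only K key heads are computed and cached

# Repeat each K/V head H // K times along the head axis so it lines up with Q.
k_rep = jnp.repeat(k, H // K, axis=-2)   # (T, H, d)
print(q.shape, k.shape, k_rep.shape)

# Query heads 0 and 1 see K/V head 0; query heads 2 and 3 see K/V head 1.
shares_head0 = bool(jnp.array_equal(k_rep[:, 0], k_rep[:, 1]))
maps_to_head1 = bool(jnp.array_equal(k_rep[:, 2], k[:, 1]))
print(shares_head0, maps_to_head1)

# Cache size in floats per layer, before and after grouping.
cache_mha = 2 * T * H * d
cache_gqa = 2 * T * K * d
print(cache_mha, cache_gqa)
```

Only `k` (and the matching `v`) would need to live in the cache; `k_rep` is a temporary created at attention time.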
A custom Flax module
flax.linen.MultiHeadDotProductAttention in this codebase doesn’t take
num_kv_heads (only the newer flax.nnx.MultiHeadAttention does).
So you’ll build a small custom module:
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class GQA(nn.Module):
        num_heads: int      # H
        num_kv_heads: int   # K (must divide H)
        qkv_features: int   # D, head_dim = D // H

        @nn.compact
        def __call__(self, x):
            H, K, D = self.num_heads, self.num_kv_heads, self.qkv_features
            d = D // H
            in_features = x.shape[-1]

            # Q gets H heads; K and V get only K heads each.
            q = nn.DenseGeneral(features=(H, d), axis=-1, name='q')(x)  # (T, H, d)
            k = nn.DenseGeneral(features=(K, d), axis=-1, name='k')(x)  # (T, K, d)
            v = nn.DenseGeneral(features=(K, d), axis=-1, name='v')(x)  # (T, K, d)

            # Repeat K and V along the head axis to match H query heads.
            repeats = H // K
            k = jnp.repeat(k, repeats, axis=-2)  # (T, H, d)
            v = jnp.repeat(v, repeats, axis=-2)  # (T, H, d)

            # Per-head scaled dot-product attention.
            q_h = jnp.moveaxis(q, -2, 0)  # (H, T, d)
            k_h = jnp.moveaxis(k, -2, 0)  # (H, T, d)
            v_h = jnp.moveaxis(v, -2, 0)  # (H, T, d)
            scores = jnp.einsum('hqd,hkd->hqk', q_h, k_h) / jnp.sqrt(d)
            weights = jax.nn.softmax(scores, axis=-1)
            out = jnp.einsum('hqk,hkd->hqd', weights, v_h)
            out = jnp.moveaxis(out, 0, -2)  # (T, H, d)

            # Final out projection from (T, H, d) back to (T, in_features).
            return nn.DenseGeneral(features=in_features, axis=(-2, -1), name='out')(out)
Why “repeat” works
Conceptually, each K/V head is shared by H/K query heads. After
jnp.repeat(k, H // K, axis=-2), the shape of K matches the shape of Q, so
standard per-head SDPA proceeds as if the heads were independent. What
shrinks is the K/V projection parameter count and the KV cache, which only
has to store the K un-repeated heads. At runtime each shared K/V head
simply serves H/K query heads with the same key and value vectors.
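One way to convince yourself that the repeat is only a broadcasting device is to compare it with a grouped formulation that never materialises the repeated K and V. The sketch below uses random tensors and assumed sizes; `attend_repeat` and `attend_grouped` are hypothetical names, and the grouped einsum is one possible alternative, not something this page prescribes.

```python
import jax
import jax.numpy as jnp

T, H, K, d = 6, 4, 2, 8   # assumed illustrative sizes
G = H // K                # query heads per K/V head
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (T, H, d))
k = jax.random.normal(kk, (T, K, d))
v = jax.random.normal(kv, (T, K, d))


def attend_repeat(q, k, v):
    # Expand K/V to H heads, then run ordinary per-head attention.
    k_r = jnp.repeat(k, G, axis=-2)  # (T, H, d)
    v_r = jnp.repeat(v, G, axis=-2)  # (T, H, d)
    scores = jnp.einsum('qhd,khd->hqk', q, k_r) / jnp.sqrt(d)
    w = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum('hqk,khd->qhd', w, v_r)  # (T, H, d)


def attend_grouped(q, k, v):
    # Split the H query heads into K groups of G; query head h maps to
    # group h // G, which is the same pairing that jnp.repeat produces.
    q_g = q.reshape(T, K, G, d)
    scores = jnp.einsum('qcgd,kcd->cgqk', q_g, k) / jnp.sqrt(d)
    w = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum('cgqk,kcd->qcgd', w, v)
    return out.reshape(T, H, d)


out_repeat = attend_repeat(q, k, v)
out_grouped = attend_grouped(q, k, v)
print(bool(jnp.allclose(out_repeat, out_grouped, atol=1e-5)))
```

Both functions compute the same attention output up to floating-point tolerance; the grouped version simply avoids the temporary (T, H, d) copies of K and V.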
jnp.repeat(x, n, axis) differs from jnp.tile: repeat duplicates
each element n times in place; tile duplicates the whole array. For
GQA we want repeat — head 0 of K is shared by query heads 0..H/K-1.
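The difference is easiest to see on a tiny array where each K/V head is labelled by its own index. In this sketch a (K, 1) array stands in for the (K, d) head axis; the sizes are assumed for illustration.

```python
import jax.numpy as jnp

H, K = 4, 2
# Each row is one K/V head, labelled by its index.
k_heads = jnp.array([[0], [1]])

repeated = jnp.repeat(k_heads, H // K, axis=0)  # each head duplicated in place
tiled = jnp.tile(k_heads, (H // K, 1))          # whole block of heads duplicated

print(repeated[:, 0].tolist())  # [0, 0, 1, 1]
print(tiled[:, 0].tolist())     # [0, 1, 0, 1]
```

With repeat, query heads 0 and 1 share K/V head 0, which is the grouping described above; with tile, query heads 0 and 2 would share it instead.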
Constraints
- H % K == 0: H must be divisible by K.
- D % H == 0: the total Q dimension must split evenly across heads.
Common LLM choices: H=32, K=8 (LLaMA-3 8B), H=64, K=8 (LLaMA-2 70B).
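These two constraints are cheap to check before building the module. The helper below is a hypothetical sketch (`check_gqa_config` is not part of Flax or of this exercise) that returns the derived head dimension and group size.

```python
def check_gqa_config(num_heads, num_kv_heads, qkv_features):
    """Return (head_dim, group_size), or raise ValueError for an invalid config."""
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads={num_heads} must be divisible by num_kv_heads={num_kv_heads}"
        )
    if qkv_features % num_heads != 0:
        raise ValueError(
            f"qkv_features={qkv_features} must be divisible by num_heads={num_heads}"
        )
    head_dim = qkv_features // num_heads      # d
    group_size = num_heads // num_kv_heads    # query heads per K/V head
    return head_dim, group_size


print(check_gqa_config(4, 2, 16))  # (4, 2)
```

Failing early with a clear message is usually friendlier than letting a later reshape or repeat produce a shape error.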
Common pitfalls
- jnp.tile instead of jnp.repeat: tile repeats blocks, repeat repeats elements. Wrong order means wrong head-to-head pairing.
- Wrong axis for repeat: axis=-2 is the head axis after DenseGeneral outputs (T, K, d).
- Forgetting the final out projection: without it, the output is per-head, with shape (T, H, d), not the model’s expected (T, D_in).
Problem
Implement mha_gqa(seed, x, num_heads, num_kv_heads, qkv_features)
using a custom GQA Module as outlined above:
- Build the module with the three config values (num_heads, num_kv_heads, qkv_features).
- Init with PRNGKey(seed) on x.
- Apply on x, return out.reshape(-1).
All test cases use H=4, K=2 so H/K = 2 repeats.
Inputs:
- seed: int.
- x: 2-D (T, D_in).
- num_heads: H.
- num_kv_heads: K, divides H.
- qkv_features: D, divisible by H.
Output: 1-D, the flattened (T, D_in) GQA output.