Multi-Query Attention (MQA)
Why this matters
Multi-Query Attention (Shazeer 2019) is the most aggressive KV-cache
compression: ALL H query heads share a SINGLE K/V head. The KV cache
shrinks by H× — a 32-head model needs 1/32 the cache memory of
standard MHA.
Used by PaLM, the original Falcon, and as the limit case of GQA where
num_kv_heads = 1.
Trade-off: sharing a single K/V head can cost model quality relative to MHA at the same model size. GQA (pos 26) was introduced as the middle ground, and most modern LLMs chose GQA over pure MQA.
What changes from GQA
Compared to GQA: same code, but num_kv_heads = 1.
- Q: (T, H, d), H heads.
- K: (T, 1, d), ONE head.
- V: (T, 1, d), ONE head.
- Repeat K and V H times: jnp.repeat(k, H, axis=-2) → (T, H, d).
- SDPA per head as usual.
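The steps above can be sketched in plain jax.numpy, starting from already-projected tensors. This is a minimal illustration, not the reference solution: the function name mqa_attention, the einsum layout, and the absence of any causal mask are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def mqa_attention(q, k, v):
    """q: (T, H, d); k, v: (T, 1, d). Returns (T, H, d)."""
    H, d = q.shape[-2], q.shape[-1]
    # Copy the single K/V head so every query head has a partner.
    k = jnp.repeat(k, H, axis=-2)  # (T, 1, d) -> (T, H, d)
    v = jnp.repeat(v, H, axis=-2)
    # Unmasked scaled dot-product attention, computed per head (assumption: no mask).
    scores = jnp.einsum('qhd,khd->hqk', q, k) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum('hqk,khd->qhd', weights, v)

T, H, d = 5, 4, 8
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (T, H, d))
k = jax.random.normal(key_k, (T, 1, d))
v = jax.random.normal(key_v, (T, 1, d))
out = mqa_attention(q, k, v)
print(out.shape)  # (5, 4, 8)
```

Every query head attends over the same keys and values; only the queries differ between heads.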
KV-cache impact at scale
For a 70B-param model with H=64, head_dim=128:
- MHA cache per layer per token: 2 * 64 * 128 = 16,384 floats.
- MQA cache: 2 * 1 * 128 = 256 floats (64× smaller).
At a 100k-token context across 80 layers, the MHA cache holds about 1.3 × 10^11 values for a single sequence, roughly 262 GB at 2 bytes per value; MQA cuts that to about 4 GB. That is why MQA was invented: to make autoregressive serving feasible at scale.
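The cache arithmetic can be checked with a few lines of Python. The helper kv_cache_bytes is a hypothetical name, and the sketch assumes a single sequence (batch size 1) stored at 2 bytes per value (bf16 or fp16); neither assumption comes from the text above.

```python
def kv_cache_bytes(num_kv_heads, head_dim, num_layers, num_tokens, bytes_per_value=2):
    # Factor 2 covers K and V; one cached vector per KV head, per layer, per token.
    per_token_per_layer = 2 * num_kv_heads * head_dim
    return per_token_per_layer * num_layers * num_tokens * bytes_per_value

# H=64, head_dim=128, 80 layers, 100k tokens (the figures used above).
mha_bytes = kv_cache_bytes(64, 128, 80, 100_000)
mqa_bytes = kv_cache_bytes(1, 128, 80, 100_000)

print(mha_bytes / 1e9, mqa_bytes / 1e9)  # 262.144 4.096  (GB)
print(mha_bytes // mqa_bytes)            # 64
```

Doubling bytes_per_value to 4 (fp32) doubles both totals and leaves the 64× ratio unchanged.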
Subtle point: parameter count
K and V projections shrink dramatically:
- MHA K-proj params: D_in * H * head_dim.
- MQA K-proj params: D_in * 1 * head_dim, H× fewer params.
The Q and out projections are unchanged. Total parameter count drops a little, but the runtime impact (smaller KV cache, less memory bandwidth at decode) is the bigger win.
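To put numbers on the projection sizes, here is a small sketch that counts weights for the four attention projections. It assumes bias-free projections and D_in = H * head_dim = 8192 for the H=64, head_dim=128 configuration; both are illustrative assumptions, and attn_proj_params is a hypothetical helper.

```python
def attn_proj_params(d_in, num_heads, head_dim, num_kv_heads):
    # Weight counts only (assumption: no bias terms).
    q = d_in * num_heads * head_dim
    k = d_in * num_kv_heads * head_dim
    v = d_in * num_kv_heads * head_dim
    out = num_heads * head_dim * d_in
    return {'q': q, 'k': k, 'v': v, 'out': out, 'total': q + k + v + out}

mha = attn_proj_params(8192, 64, 128, num_kv_heads=64)
mqa = attn_proj_params(8192, 64, 128, num_kv_heads=1)

print(mha['k'], mqa['k'])          # 67108864 1048576
print(mha['total'], mqa['total'])  # 268435456 136314880
```

Under these assumptions the attention projections of one layer shrink by roughly half, while the feed-forward weights, which this sketch does not count, stay the same.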
Custom Flax module
Same shape as the GQA module from pos 26, hard-coded to num_kv_heads = 1:
import jax.numpy as jnp
from flax import linen as nn

class MQA(nn.Module):
    num_heads: int
    qkv_features: int

    @nn.compact
    def __call__(self, x):
        H = self.num_heads
        head_dim = self.qkv_features // H
        in_features = x.shape[-1]  # D_in, the width the out projection maps back to
        q = nn.DenseGeneral(features=(H, head_dim), axis=-1, name='q')(x)  # (T, H, d)
        k = nn.DenseGeneral(features=(1, head_dim), axis=-1, name='k')(x)  # (T, 1, d)
        v = nn.DenseGeneral(features=(1, head_dim), axis=-1, name='v')(x)  # (T, 1, d)
        k = jnp.repeat(k, H, axis=-2)  # broadcast 1 -> H heads
        v = jnp.repeat(v, H, axis=-2)
        # ... usual per-head SDPA + out projection
Common pitfalls
- Forgetting the repeat: without it, K has shape (T, 1, d) and can't pair with Q's H heads in the einsum. (Broadcasting can work in some cases, but the moveaxis pattern requires an explicit shape match.)
- jnp.tile vs jnp.repeat: with a single K/V head they happen to give the same result, but using repeat keeps the GQA template consistent.
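The tile/repeat point is easy to verify on small arrays. The sketch below assumes the common GQA convention that query heads sharing a K/V head are contiguous; under that convention repeat yields the matching head order and tile does not.

```python
import jax.numpy as jnp

T, H, d = 3, 4, 2

# One K/V head (MQA): repeat and tile produce identical arrays.
k1 = jnp.arange(T * 1 * d, dtype=jnp.float32).reshape(T, 1, d)
same_for_mqa = bool(jnp.array_equal(jnp.repeat(k1, H, axis=-2),
                                    jnp.tile(k1, (1, H, 1))))

# Two K/V heads (GQA with groups of 2): the head order now differs.
k2 = jnp.arange(T * 2 * d, dtype=jnp.float32).reshape(T, 2, d)
rep = jnp.repeat(k2, 2, axis=-2)  # heads ordered [0, 0, 1, 1]
til = jnp.tile(k2, (1, 2, 1))     # heads ordered [0, 1, 0, 1]
differ_for_gqa = not bool(jnp.array_equal(rep, til))

print(same_for_mqa, differ_for_gqa)  # True True
```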
Problem
Implement mha_mqa(seed, x, num_heads, qkv_features) using a custom
MQA Module (similar to GQA but hard-coded to one K/V head):
- Project Q with H heads, K and V with 1 head.
- jnp.repeat(k, H, axis=-2) and same for v to broadcast to H heads.
- Per-head SDPA as in GQA.
- Final out projection back to D_in.
Inputs:
- seed: int.
- x: 2-D (T, D_in).
- num_heads: H.
- qkv_features: D, divisible by H.
Output: 1-D, the flattened MQA output.