NNX Multi-Query Attention (MQA)
Why this matters
Multi-Query Attention (MQA) is the extreme case of GQA:
num_kv_heads = 1. A single K head and a single V head are shared
across every Q head. This is the maximum possible KV-cache compression:
relative to standard multi-head attention, the cache shrinks by a factor of
num_heads.
PaLM (Google) used MQA, and Falcon and StarCoder do too. The trade-off: quality drops slightly more than with GQA, because all Q heads now share one set of keys and values, but inference latency and memory footprint shrink dramatically.
Difference from GQA
GQA with num_kv_heads = 1 is MQA. You could literally re-use
the GQA implementation and pass num_kv_heads=1. We’re writing
a dedicated path here for clarity and because the broadcast is
even simpler — no need to repeat across groups, just broadcast
one head across all of Q.
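To check the claim that MQA is just GQA with num_kv_heads = 1, here is a minimal sketch comparing a hypothetical gqa_core (it repeats each KV head across its group of H // Hkv query heads) with an mqa_core that broadcasts one head. Both helper names are illustrative, not the GQA lesson's actual code, and both skip masking and the projections to isolate the attention step.

```python
import jax
import jax.numpy as jnp


def gqa_core(q, k, v):
    """Hypothetical GQA attention core (no mask).

    q: (H, T, Dh); k, v: (Hkv, T, Dh) with H % Hkv == 0.
    Query head h reads KV head h // (H // Hkv).
    """
    H, Hkv = q.shape[0], k.shape[0]
    group = H // Hkv
    k = jnp.repeat(k, group, axis=0)  # (H, T, Dh)
    v = jnp.repeat(v, group, axis=0)  # (H, T, Dh)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def mqa_core(q, k, v):
    """Hypothetical MQA attention core: k, v are (T, Dh) with no head axis."""
    H, T, Dh = q.shape
    k = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
    v = jnp.broadcast_to(v[None, :, :], (H, T, Dh))
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(Dh)
    return jax.nn.softmax(scores, axis=-1) @ v


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
H, T, Dh = 8, 5, 16
q = jax.random.normal(kq, (H, T, Dh))
k = jax.random.normal(kk, (T, Dh))
v = jax.random.normal(kv, (T, Dh))

# GQA with num_kv_heads = 1: give K and V a head axis of size 1.
out_gqa = gqa_core(q, k[None], v[None])
out_mqa = mqa_core(q, k, v)
print(bool(jnp.allclose(out_gqa, out_mqa, atol=1e-5)))  # True
```

With Hkv = 1 the group size is H, so repeating the single KV head H times yields exactly the tensor the broadcast produces.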
Projections:
- q_proj: d_model -> num_heads * head_dim (every Q head)
- k_proj: d_model -> head_dim (ONE K head; no head axis in the output)
- v_proj: d_model -> head_dim (ONE V head)
- out_proj: num_heads * head_dim -> d_model
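A quick way to see how much smaller the K and V projections are is to count parameters. The sketch below assumes each projection is a dense layer with a bias (nnx.Linear uses a bias by default); mqa_param_counts is a hypothetical helper, not part of Flax.

```python
def mqa_param_counts(d_model: int, num_heads: int, use_bias: bool = True) -> dict:
    """Hypothetical helper: weights (+ bias) for each MQA projection."""
    head_dim = d_model // num_heads

    def linear(n_in: int, n_out: int) -> int:
        # Weight matrix n_in x n_out, plus one bias per output unit.
        return n_in * n_out + (n_out if use_bias else 0)

    return {
        "q_proj": linear(d_model, num_heads * head_dim),
        "k_proj": linear(d_model, head_dim),  # single K head
        "v_proj": linear(d_model, head_dim),  # single V head
        "out_proj": linear(num_heads * head_dim, d_model),
    }


counts = mqa_param_counts(d_model=512, num_heads=8)
print(counts)
# {'q_proj': 262656, 'k_proj': 32832, 'v_proj': 32832, 'out_proj': 262656}
```

In full MHA, k_proj and v_proj would each match q_proj (262,656 parameters here); in MQA they are num_heads times narrower.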
Broadcasting one K/V head
K and V come out as (T, head_dim) — no head axis at all. To
line up with q of shape (num_heads, T, head_dim), broadcast:
k = jnp.broadcast_to(k[None, :, :], (num_heads, T, head_dim))
v = jnp.broadcast_to(v[None, :, :], (num_heads, T, head_dim))
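arr[None, :, :] adds a leading singleton axis: (1, T, Dh).
jnp.broadcast_to expands that to (H, T, Dh), and SDPA then proceeds
normally. Note that JAX, unlike NumPy, has no views: run eagerly,
broadcast_to returns a full (H, T, Dh) array. Under jax.jit the compiler
may avoid materializing it, but that is XLA's decision, not a guarantee.
Equivalent: jnp.repeat(k[None], num_heads, axis=0). It gives the same
values; broadcast_to just states the intent (one head, reused by all) more directly.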
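The sketch below checks that broadcast_to and repeat agree, then computes the (unscaled) score matrix three ways. The last two never ask for an explicit (H, T, Dh) copy of K: jnp.matmul broadcasts a size-1 batch axis on its own, and jnp.einsum can contract a head-less k directly. Shapes and seeds are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

H, T, Dh = 4, 6, 8
q = jax.random.normal(jax.random.PRNGKey(1), (H, T, Dh))  # all query heads
k = jax.random.normal(jax.random.PRNGKey(0), (T, Dh))     # the single K head

# 1) Explicit broadcast vs. explicit repeat: identical shape and values.
k_b = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
k_r = jnp.repeat(k[None], H, axis=0)
assert k_b.shape == k_r.shape == (H, T, Dh)
assert bool(jnp.array_equal(k_b, k_r))

# 2) Three ways to get the (H, T, T) score matrix (before the 1/sqrt(Dh) scale).
s_explicit = jnp.matmul(q, k_b.transpose(0, 2, 1))      # uses the (H, T, Dh) array
s_implicit = jnp.matmul(q, k[None].transpose(0, 2, 1))  # matmul broadcasts the size-1 batch axis
s_einsum = jnp.einsum("htd,sd->hts", q, k)              # k never gets a head axis

print(s_explicit.shape)  # (4, 6, 6)
print(bool(jnp.allclose(s_explicit, s_implicit, atol=1e-5)),
      bool(jnp.allclose(s_explicit, s_einsum, atol=1e-5)))  # True True
```

The einsum form is a design choice rather than a requirement: it makes "one K head shared by every Q head" explicit in the subscripts (k has no h index) and sidesteps the question of whether the broadcast gets materialized.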
Worked sketch
```python
import jax
import jax.numpy as jnp
from flax import nnx


class MQAMHA(nnx.Module):
    def __init__(self, d_model, num_heads, rngs):
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Q keeps every head; K and V each get a single head of width head_dim.
        self.q_proj = nnx.Linear(d_model, num_heads * self.head_dim, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, self.head_dim, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, self.head_dim, rngs=rngs)
        self.out_proj = nnx.Linear(num_heads * self.head_dim, d_model, rngs=rngs)

    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)  # (H, T, Dh)
        k = self.k_proj(x)  # (T, Dh)
        v = self.v_proj(x)  # (T, Dh)
        # Share the single K/V head across all H query heads.
        k = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
        v = jnp.broadcast_to(v[None, :, :], (H, T, Dh))
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)  # (H, T, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)  # (T, H * Dh)
        return self.out_proj(concat)
```
How much does this save?
With num_heads=64, head_dim=128, max_len=4096 and fp16 (2 bytes per element); the leading factor of 2 counts the K cache and the V cache:
- Full MHA cache: 2 * 4096 * 64 * 128 * 2 bytes = 134,217,728 bytes = 128 MiB (≈134 MB) per layer per sequence.
- MQA cache: 2 * 4096 * 1 * 128 * 2 bytes = 2,097,152 bytes = 2 MiB per layer per sequence.
A 64x reduction. For a 32-layer model: 4 GiB -> 64 MiB per sequence. At batch 16 that’s 64 GiB -> 1 GiB.
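The same arithmetic as a small function, so other configurations are easy to try. kv_cache_bytes is a hypothetical helper; it assumes a dense fp16 cache with no padding or allocator overhead.

```python
def kv_cache_bytes(num_layers, batch, max_len, num_kv_heads, head_dim, bytes_per_elem=2):
    """Hypothetical helper: bytes for the K and V caches (the leading 2 counts K and V)."""
    return 2 * num_layers * batch * max_len * num_kv_heads * head_dim * bytes_per_elem


MiB, GiB = 2**20, 2**30

per_layer_mha = kv_cache_bytes(1, 1, 4096, 64, 128)  # 134,217,728 bytes
per_layer_mqa = kv_cache_bytes(1, 1, 4096, 1, 128)   # 2,097,152 bytes
print(per_layer_mha / MiB, per_layer_mqa / MiB)      # 128.0 2.0

# 32 layers, batch 16: MHA vs. GQA with 8 KV heads vs. MQA.
print(kv_cache_bytes(32, 16, 4096, 64, 128) / GiB)   # 64.0
print(kv_cache_bytes(32, 16, 4096, 8, 128) / GiB)    # 8.0
print(kv_cache_bytes(32, 16, 4096, 1, 128) / GiB)    # 1.0
```

The cache scales linearly in num_kv_heads, which is why GQA's 4 or 8 KV heads sit between the two extremes.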
The price is quality. With a single K/V head, every query head scores against the same keys and reads from the same values, which limits how different the heads’ attention patterns can be. GQA strikes a middle ground (typically 4 or 8 KV heads) that keeps most of the speedup with less quality loss.
Common pitfalls
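- Broadcasting to the wrong layout. Broadcasting aligns trailing axes, so jnp.broadcast_to(k, (H, T, Dh)) with k.shape == (T, Dh) actually works: the leading axis is added implicitly, and k[None, :, :] only makes the intent explicit. The real trap is a target shape that doesn’t match q, such as (T, H, Dh): that raises a shape error when T != H, and silently produces the wrong keys when T == H.
- Wrong head count in out_proj. Q is full width (num_heads * head_dim), so out_proj’s input dim is full width, not head_dim.
- Reshaping K with a head dim. The K projection is d_model -> head_dim directly; there’s no Hkv * Dh to split.
- Reaching for jnp.repeat. It works and gives the same values, but it spells out H copies; broadcast_to says “one head, reused” directly. Neither is a NumPy-style view in JAX.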
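Here is a small demonstration of the layout pitfall, using made-up sizes with T == H on purpose. Because broadcasting aligns trailing axes, the (H, T, Dh) target works with or without k[None], while a (T, H, Dh) target quietly lines k’s T axis up with the H axis.

```python
import jax.numpy as jnp

H, T, Dh = 4, 4, 3  # T == H, so the wrong layout does not raise
k = jnp.arange(T * Dh, dtype=jnp.float32).reshape(T, Dh)  # distinct rows per position

# Correct target: the missing leading axis is added implicitly.
good = jnp.broadcast_to(k, (H, T, Dh))
assert bool(jnp.array_equal(good, jnp.broadcast_to(k[None], (H, T, Dh))))

# Wrong target (T, H, Dh): bad[t, h] == k[h], not k[t].
bad = jnp.broadcast_to(k, (T, H, Dh))
print(bool(jnp.array_equal(bad[1, 0], k[1])))  # False: position 1 lost its key
print(bool(jnp.array_equal(bad[1, 0], k[0])))  # True: it got position 0's key instead
```

With T != H the same wrong call fails loudly with a shape error, which is the lucky case.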
Problem
Write mha_mqa(seed, x, num_heads, d_model):
- Define MQAMHA(nnx.Module): Q full width, K and V of width head_dim (single head), out_proj full width.
- __call__: project, reshape Q to (H, T, Dh), broadcast K/V from (T, Dh) to (H, T, Dh), SDPA, concat, out_proj.
- Cast num_heads and d_model to int. Build nnx.Rngs(int(seed)). Return the output flattened.
Inputs:
- seed: int (passed as float).
- x: 2-D, shape (T, d_model).
- num_heads, d_model: ints (passed as floats).
Output: 1-D flattened array.