NNX Grouped-Query Attention (GQA)
Why this matters
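The KV cache (pos 35) is a dominant memory cost at LLM inference
time, especially at long context and large batch. A full-MHA model
with LLaMA-2 70B's shape (80 layers, num_heads=64, head_dim=128) and
a 4096-token context would need 2 (K and V) * 80 * 64 * 128 * 2 bytes
(fp16) = 2.5 MiB per token, or about 10 GiB per sequence. At
batch_size=8 that is about 80 GiB just for K/V.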
Grouped-Query Attention (GQA) reduces that cost with a small
architectural change (an existing MHA checkpoint can be uptrained to
GQA instead of retrained from scratch): have FEWER K/V heads than Q
heads, with each K/V head shared across a group of Q heads. LLaMA-2
70B uses num_kv_heads=8 instead of 64, an 8x reduction in cache size.
The quality drop is small; the memory savings are large.
The split
Pick num_kv_heads such that num_heads % num_kv_heads == 0. The
group size is num_heads // num_kv_heads: that many Q heads share
one K/V head.
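With consecutive grouping, the K/V head a Q head reads from is just an integer division by the group size. A minimal sketch, assuming that convention; `kv_head_for` is a hypothetical helper, not part of any library:

```python
# Sketch: which K/V head does each Q head read from?
# Assumes consecutive grouping: Q heads 0..G-1 share KV head 0,
# Q heads G..2G-1 share KV head 1, and so on.

def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Hypothetical helper: index of the K/V head used by Q head `q_head`."""
    assert num_heads % num_kv_heads == 0, "groups must partition the Q heads"
    group_size = num_heads // num_kv_heads
    return q_head // group_size


num_heads, num_kv_heads = 8, 2  # toy sizes, so group_size = 4
mapping = [kv_head_for(h, num_heads, num_kv_heads) for h in range(num_heads)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With LLaMA-2 70B's 64 Q heads and 8 KV heads, Q head 63 lands in group 7.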
Projections:
- q_proj: d_model -> num_heads * head_dim (all Q heads, full size)
- k_proj: d_model -> num_kv_heads * head_dim (FEWER K heads)
- v_proj: d_model -> num_kv_heads * head_dim (FEWER V heads)
- out_proj: num_heads * head_dim -> d_model
K and V are smaller, and that is the whole win. The cache holds these
smaller tensors, so its memory drops by a factor of
num_heads / num_kv_heads.
Broadcasting K and V to match Q heads
SDPA (scaled dot-product attention) still needs Q and K to have the
SAME number of heads, so we expand K and V by repeating each head
group_size times:
repeat = num_heads // num_kv_heads
k = jnp.repeat(k, repeat, axis=0) # (num_kv_heads, T, Dh) -> (num_heads, T, Dh)
v = jnp.repeat(v, repeat, axis=0)
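A quick shape check of the expansion, using random toy tensors (sizes are illustrative assumptions). It also confirms that expanded head h is an exact copy of KV head h // repeat:

```python
import jax
import jax.numpy as jnp

num_heads, num_kv_heads, T, Dh = 8, 2, 5, 4
repeat = num_heads // num_kv_heads  # group_size = 4

k = jax.random.normal(jax.random.PRNGKey(0), (num_kv_heads, T, Dh))
k_rep = jnp.repeat(k, repeat, axis=0)
print(k.shape, "->", k_rep.shape)  # (2, 5, 4) -> (8, 5, 4)

# Every expanded head h is an exact copy of KV head h // repeat.
for h in range(num_heads):
    assert jnp.array_equal(k_rep[h], k[h // repeat])
```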
jnp.repeat(arr, n, axis=0) repeats along axis 0:
[A, B] -> [A, A, B, B] for n=2. After this expansion, every
Q head has a “copy” of its group’s K/V. The math is identical to
full MHA from this point.
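Another way to see the equivalence: GQA behaves like an MHA layer whose K (and V) projection weights are tied within each group. The sketch below uses bare weight matrices with no bias (the names `W_k_small` and `W_k_full` are hypothetical) and checks that projecting with duplicated weight columns gives the same K as projecting small and then repeating:

```python
import jax
import jax.numpy as jnp

T, d_model, H, Hkv = 5, 16, 4, 2
Dh = d_model // H  # 4
G = H // Hkv       # group_size = 2

kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (T, d_model))
W_k_small = jax.random.normal(kw, (d_model, Hkv * Dh))  # GQA K weights (no bias)

# GQA path: project to Hkv heads, then repeat heads along axis 0.
k_small = (x @ W_k_small).reshape(T, Hkv, Dh).transpose(1, 0, 2)  # (Hkv, T, Dh)
k_gqa = jnp.repeat(k_small, G, axis=0)                             # (H, T, Dh)

# MHA path with tied heads: copy each KV head's weight columns G times,
# giving a full-width (d_model, H * Dh) K projection.
W_k_full = jnp.repeat(W_k_small.reshape(d_model, Hkv, Dh), G, axis=1)
W_k_full = W_k_full.reshape(d_model, H * Dh)
k_mha = (x @ W_k_full).reshape(T, H, Dh).transpose(1, 0, 2)        # (H, T, Dh)

print(jnp.allclose(k_gqa, k_mha, atol=1e-4))  # True
```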
Why repeat instead of tile?
jnp.tile([A, B], 2) -> [A, B, A, B] (full pattern repeats).
jnp.repeat([A, B], 2, axis=0) -> [A, A, B, B] (each element
repeats in place).
For GQA we want each KV head to align with group_size consecutive
Q heads — repeat is the right primitive.
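The difference is easy to see on a toy array where 10 and 20 stand in for KV heads A and B. For a real (Hkv, T, Dh) tensor, `jnp.tile` needs a per-axis reps tuple to act only on the heads axis:

```python
import jax.numpy as jnp

kv = jnp.array([10, 20])  # stand-ins for two KV heads: A = 10, B = 20
group_size = 2

print(jnp.repeat(kv, group_size, axis=0))  # [10 10 20 20]: Q heads 0,1 read A; 2,3 read B
print(jnp.tile(kv, group_size))            # [10 20 10 20]: Q heads 0,2 read A; 1,3 read B

# Same comparison on a (Hkv, T, Dh) = (2, 3, 4) tensor.
k = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
k_tile = jnp.tile(k, (group_size, 1, 1))
k_rep = jnp.repeat(k, group_size, axis=0)
print(k_tile.shape == k_rep.shape)  # True: same shape, different head order
print(jnp.array_equal(k_rep[1], k[0]), jnp.array_equal(k_tile[1], k[0]))  # True False
```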
Worked sketch
import jax
import jax.numpy as jnp
from flax import nnx

class GQAMHA(nnx.Module):
    def __init__(self, d_model, num_heads, num_kv_heads, rngs):
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nnx.Linear(d_model, num_heads * self.head_dim, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, num_kv_heads * self.head_dim, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, num_kv_heads * self.head_dim, rngs=rngs)
        self.out_proj = nnx.Linear(num_heads * self.head_dim, d_model, rngs=rngs)

    def __call__(self, x):  # x: (T, d_model), unbatched
        T, _ = x.shape
        H, Hkv, Dh = self.num_heads, self.num_kv_heads, self.head_dim
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)    # (H, T, Dh)
        k = self.k_proj(x).reshape(T, Hkv, Dh).transpose(1, 0, 2)  # (Hkv, T, Dh)
        v = self.v_proj(x).reshape(T, Hkv, Dh).transpose(1, 0, 2)  # (Hkv, T, Dh)
        repeat = H // Hkv
        k = jnp.repeat(k, repeat, axis=0)  # broadcast K to Q heads: (H, T, Dh)
        v = jnp.repeat(v, repeat, axis=0)  # broadcast V to Q heads: (H, T, Dh)
        # Scaled dot-product attention per head (no causal mask): (H, T, T)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)                          # (H, T, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)    # (T, H * Dh)
        return self.out_proj(concat)
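The `jnp.repeat` calls above materialize full (H, T, Dh) copies of K and V. One alternative formulation, named here only as a variant and not what the sketch uses, skips the explicit repeat: reshape Q into (Hkv, group_size, T, Dh) so each consecutive group lines up with its KV head, then contract with `jnp.einsum`. The check below uses random toy tensors and compares its attention scores against the repeat-based ones; whether XLA avoids an internal copy is up to the compiler.

```python
import jax
import jax.numpy as jnp

H, Hkv, T, Dh = 8, 2, 5, 4
G = H // Hkv  # group_size

kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (H, T, Dh))
k_small = jax.random.normal(kk, (Hkv, T, Dh))  # the small K a cache would hold

# Repeat-based scores, as in the worked sketch: (H, T, T)
k_rep = jnp.repeat(k_small, G, axis=0)
scores_rep = q @ k_rep.transpose(0, 2, 1) / jnp.sqrt(Dh)

# Grouped scores: Q head h = kv * G + g, so a plain reshape groups them.
q_grouped = q.reshape(Hkv, G, T, Dh)
scores_grp = jnp.einsum("kgtd,ksd->kgts", q_grouped, k_small) / jnp.sqrt(Dh)
scores_grp = scores_grp.reshape(H, T, T)

print(jnp.allclose(scores_rep, scores_grp, atol=1e-5))  # True
```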
Why this saves cache memory
The repeat is a forward-pass op — it does not store anything new in
the cache. Cache holds the SMALL (T, num_kv_heads, head_dim) K
and V; the broadcast happens on read. So:
-
cache size:
2 * max_len * num_kv_heads * head_dim -
vs MHA:
2 * max_len * num_heads * head_dim
LLaMA-2 70B’s 8x reduction is num_heads/num_kv_heads = 64/8 = 8.
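A small calculator that extends the per-layer formulas across layers and bytes. Assumptions not stated in the formulas themselves: 80 layers and 2-byte (fp16/bf16) storage, the commonly cited LLaMA-2 70B settings; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(max_len, kv_heads, head_dim, num_layers, bytes_per_elem=2, batch=1):
    """Hypothetical helper: bytes for K and V = 2 * layers * tokens * heads * head_dim."""
    return 2 * num_layers * max_len * kv_heads * head_dim * bytes_per_elem * batch


GiB = 1024 ** 3
mha = kv_cache_bytes(4096, kv_heads=64, head_dim=128, num_layers=80)
gqa = kv_cache_bytes(4096, kv_heads=8, head_dim=128, num_layers=80)

print(mha / GiB)   # 10.0  (full MHA, one 4096-token sequence)
print(gqa / GiB)   # 1.25  (GQA with 8 KV heads)
print(mha // gqa)  # 8
print(kv_cache_bytes(4096, 64, 128, 80, batch=8) / GiB)  # 80.0
```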
Common pitfalls
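- Repeat axis wrong. Repeat along the heads axis (axis 0 after the
  transpose). axis=1 would duplicate timesteps instead of heads, which
  is nonsense.
- jnp.tile instead of jnp.repeat. Tile repeats the whole head sequence
  ([A, B, A, B]), pairing Q heads with KV heads in the wrong order;
  repeat ([A, A, B, B]) is what you want.
- num_heads not divisible by num_kv_heads. Then the groups don't
  partition cleanly. Assert it.
- Repeat factor swapped. H // Hkv is group_size. Hkv // H is 0 when
  Hkv < H, so jnp.repeat silently returns an empty head axis and the
  error only surfaces later as a shape mismatch.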
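The axis and factor pitfalls show up directly in the shapes. A toy demonstration with zero tensors (sizes are illustrative):

```python
import jax.numpy as jnp

H, Hkv, T, Dh = 8, 2, 5, 4
k = jnp.zeros((Hkv, T, Dh))

# Correct: repeat the heads axis by H // Hkv.
print(jnp.repeat(k, H // Hkv, axis=0).shape)  # (8, 5, 4)

# Wrong axis: timesteps get duplicated, heads do not.
print(jnp.repeat(k, H // Hkv, axis=1).shape)  # (2, 20, 4)

# Swapped factor: Hkv // H == 0, so the head axis comes back empty.
print(Hkv // H, jnp.repeat(k, Hkv // H, axis=0).shape)  # 0 (0, 5, 4)
```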
Problem
Write mha_gqa(seed, x, num_heads, num_kv_heads, d_model):
- Define GQAMHA(nnx.Module) with Q/K/V projections of differing
  widths (Q is full; K/V are num_kv_heads * head_dim), plus an
  out_proj of full width.
- In __call__: project, reshape Q to (H, T, Dh) and K/V to
  (Hkv, T, Dh), repeat K/V along axis 0 by H // Hkv to match Q's head
  count, apply SDPA, concatenate the heads back, and apply out_proj.
- Cast all dimension args from float to int. Return the output
  flattened.
Inputs:
- seed: int (passed as float).
- x: 2-D array of shape (T, d_model).
- num_heads, num_kv_heads, d_model: ints (passed as floats), with
  num_heads % num_kv_heads == 0 and d_model % num_heads == 0.
Output: 1-D flattened.