
NNX Grouped-Query Attention (GQA)

Why this matters

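The KV cache (pos 35) is often the dominant memory cost at LLM inference time. Take LLaMA-2 70B’s shapes (80 layers, num_heads=64, head_dim=128), a 4096-token context and an fp16 cache. If every Q head had its own K/V head (plain MHA), the cache would need 2 * 80 * 4096 * 64 * 128 * 2 bytes = 10 GiB per sequence. At batch_size=8 that is 80 GiB for K/V alone.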

Grouped-Query Attention (GQA) reduces that cost without retraining everything. It uses FEWER K/V heads than Q heads, and each K/V head is shared across a group of Q heads. LLaMA-2 70B uses num_kv_heads=8 instead of 64, which shrinks the cache 8x. The quality drop is small; the memory savings are large.

The split

Pick num_kv_heads such that num_heads % num_kv_heads == 0. The group size, group_size = num_heads // num_kv_heads, is the number of Q heads that share one K/V head.
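The helper below makes the grouping concrete by returning which K/V head a given Q head reads. kv_head_for_q_head is a hypothetical name used only for illustration. It assumes contiguous groups: Q heads 0 to group_size - 1 share K/V head 0, and so on. That is the layout jnp.repeat produces further down.

```python
def kv_head_for_q_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Hypothetical helper: index of the K/V head that Q head `q_head` reads."""
    assert num_heads % num_kv_heads == 0, "groups must partition the Q heads"
    group_size = num_heads // num_kv_heads
    # Contiguous grouping: each block of group_size Q heads shares one K/V head.
    return q_head // group_size


# 8 Q heads sharing 2 K/V heads, so group_size = 4.
mapping = [kv_head_for_q_head(h, num_heads=8, num_kv_heads=2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The two extremes are worth knowing. With num_kv_heads = num_heads, every Q head gets its own K/V head, which is plain MHA. With num_kv_heads = 1, every Q head maps to head 0, which is multi-query attention (MQA).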

Projections:

  • q_proj: d_model -> num_heads * head_dim (all Q heads, full size)
  • k_proj: d_model -> num_kv_heads * head_dim (FEWER K heads)
  • v_proj: d_model -> num_kv_heads * head_dim (FEWER V heads)
  • out_proj: num_heads * head_dim -> d_model
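The projection widths translate directly into weight-matrix sizes. This pure-Python sketch assumes d_model = 8192, which is 64 heads times head_dim 128 (the LLaMA-2 70B sizes quoted above). It ignores biases; note that nnx.Linear adds one by default. projection_weight_shapes is a hypothetical helper name.

```python
def projection_weight_shapes(d_model: int, num_heads: int, num_kv_heads: int):
    """Hypothetical helper: (in, out) weight shapes of the four GQA projections."""
    head_dim = d_model // num_heads
    return {
        "q_proj": (d_model, num_heads * head_dim),       # full width
        "k_proj": (d_model, num_kv_heads * head_dim),    # narrower
        "v_proj": (d_model, num_kv_heads * head_dim),    # narrower
        "out_proj": (num_heads * head_dim, d_model),     # full width
    }


shapes = projection_weight_shapes(d_model=8192, num_heads=64, num_kv_heads=8)
for name, (fan_in, fan_out) in shapes.items():
    print(f"{name:8s} {fan_in} x {fan_out} = {fan_in * fan_out:,} weights")
```

Under these assumptions, k_proj and v_proj are 8192 x 1024, one eighth the size of q_proj.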

K and V are smaller, and that is the whole win. The cache holds these smaller tensors, so cache memory drops by a factor of num_heads / num_kv_heads.

Broadcasting K and V to match Q heads

Scaled dot-product attention (SDPA) still needs Q and K to have the SAME number of heads, because the score matmul is batched over the head axis. So we expand K and V by repeating each head group_size times:

repeat = num_heads // num_kv_heads
k = jnp.repeat(k, repeat, axis=0)   # (num_kv_heads, T, Dh) -> (num_heads, T, Dh)
v = jnp.repeat(v, repeat, axis=0)

jnp.repeat(arr, n, axis=0) repeats along axis 0: [A, B] -> [A, A, B, B] for n=2. After this expansion, every Q head has a “copy” of its group’s K/V. The math is identical to full MHA from this point.
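Here is a quick check of that claim with toy sizes chosen for illustration. Each K/V head is filled with its own index and expanded with jnp.repeat. The check then confirms that expanded head h comes from K/V head h // group_size.

```python
import jax.numpy as jnp

num_heads, num_kv_heads, T, Dh = 8, 2, 3, 4
repeat = num_heads // num_kv_heads  # group_size = 4

# Fill K/V head i with the value i, so each expanded head reveals its source.
k = jnp.broadcast_to(
    jnp.arange(num_kv_heads, dtype=jnp.float32)[:, None, None], (num_kv_heads, T, Dh)
)
k_rep = jnp.repeat(k, repeat, axis=0)  # (Hkv, T, Dh) -> (H, T, Dh)

print(k_rep.shape)              # (8, 3, 4)
print(k_rep[:, 0, 0].tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```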

Why repeat instead of tile?

jnp.tile([A, B], 2) -> [A, B, A, B] (full pattern repeats). jnp.repeat([A, B], 2, axis=0) -> [A, A, B, B] (each element repeats in place).

For GQA we want each KV head to align with group_size consecutive Q heads — repeat is the right primitive.
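The sketch below shows the difference using integer labels as stand-ins for K/V heads A (0) and B (1). The labels are illustrative only. The second half shows why the mix-up is easy to miss: tiled along the head axis, a real (Hkv, T, Dh) tensor ends up with the same shape either way.

```python
import jax.numpy as jnp

kv_heads = jnp.array([0, 1])   # stand-ins for K/V heads A (0) and B (1)
group_size = 2

repeated = jnp.repeat(kv_heads, group_size, axis=0)
tiled = jnp.tile(kv_heads, group_size)

print(repeated.tolist())  # [0, 0, 1, 1]: Q heads 0-1 read A, Q heads 2-3 read B
print(tiled.tolist())     # [0, 1, 0, 1]: Q heads 0 and 2 read A, 1 and 3 read B

# Same output shape for both routes, so choosing the wrong one raises no error.
k = jnp.zeros((2, 5, 4))  # (Hkv, T, Dh)
rep_shape = jnp.repeat(k, group_size, axis=0).shape
tile_shape = jnp.tile(k, (group_size, 1, 1)).shape
print(rep_shape, tile_shape)  # (4, 5, 4) (4, 5, 4)
```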

Worked sketch

import jax
import jax.numpy as jnp
from flax import nnx

class GQAMHA(nnx.Module):
    def __init__(self, d_model, num_heads, num_kv_heads, rngs):
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        # Q is full width; K and V only carry num_kv_heads heads.
        self.q_proj = nnx.Linear(d_model, num_heads * self.head_dim, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, num_kv_heads * self.head_dim, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, num_kv_heads * self.head_dim, rngs=rngs)
        self.out_proj = nnx.Linear(num_heads * self.head_dim, d_model, rngs=rngs)

    def __call__(self, x):
        # x: (T, d_model), one unbatched sequence; no causal mask in this sketch.
        T, _ = x.shape
        H, Hkv, Dh = self.num_heads, self.num_kv_heads, self.head_dim
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)    # (H, T, Dh)
        k = self.k_proj(x).reshape(T, Hkv, Dh).transpose(1, 0, 2)  # (Hkv, T, Dh)
        v = self.v_proj(x).reshape(T, Hkv, Dh).transpose(1, 0, 2)  # (Hkv, T, Dh)
        repeat = H // Hkv                        # group_size
        k = jnp.repeat(k, repeat, axis=0)        # broadcast K to Q heads: (H, T, Dh)
        v = jnp.repeat(v, repeat, axis=0)        # broadcast V to Q heads: (H, T, Dh)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)                            # (H, T, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)     # (T, d_model)
        return self.out_proj(concat)
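One way to sanity-check the repeat logic outside NNX is to compare it against a loop in which each Q head explicitly reads K/V head h // group_size. This is a standalone sketch in plain jax.numpy with random inputs and, like the sketch above, no causal mask. gqa_attention_repeat and gqa_attention_loop are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def gqa_attention_repeat(q, k, v):
    """SDPA with K/V expanded via jnp.repeat. q: (H, T, Dh); k, v: (Hkv, T, Dh)."""
    repeat = q.shape[0] // k.shape[0]
    k = jnp.repeat(k, repeat, axis=0)
    v = jnp.repeat(v, repeat, axis=0)
    scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(q.shape[-1])
    return jnp.matmul(jax.nn.softmax(scores, axis=-1), v)


def gqa_attention_loop(q, k, v):
    """Reference: Q head h attends with K/V head h // group_size, one head at a time."""
    group_size = q.shape[0] // k.shape[0]
    outs = []
    for h in range(q.shape[0]):
        kh, vh = k[h // group_size], v[h // group_size]  # (T, Dh) each
        w = jax.nn.softmax(q[h] @ kh.T / jnp.sqrt(q.shape[-1]), axis=-1)
        outs.append(w @ vh)
    return jnp.stack(outs)  # (H, T, Dh)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
H, Hkv, T, Dh = 8, 2, 5, 16
q = jax.random.normal(key_q, (H, T, Dh))
k = jax.random.normal(key_k, (Hkv, T, Dh))
v = jax.random.normal(key_v, (Hkv, T, Dh))

out_repeat = gqa_attention_repeat(q, k, v)
out_loop = gqa_attention_loop(q, k, v)
print(out_repeat.shape, bool(jnp.allclose(out_repeat, out_loop, atol=1e-5)))
```

The two versions should agree up to floating-point noise. The repeat version does all heads in one batched matmul, which is why the module uses it.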

Why this saves cache memory

The repeat is a forward-pass op: it stores nothing new in the cache. The cache holds the SMALL (T, num_kv_heads, head_dim) K and V, and the broadcast happens on read. Counted in elements per layer and per sequence (the leading 2 is one K plus one V):

  • GQA cache: 2 * max_len * num_kv_heads * head_dim
  • vs MHA: 2 * max_len * num_heads * head_dim

LLaMA-2 70B’s 8x reduction is num_heads/num_kv_heads = 64/8 = 8.
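The formulas can be turned into a byte count. kv_cache_bytes is a hypothetical helper. The sketch assumes 80 layers (LLaMA-2 70B's depth), a 4096-token context and 2 bytes per element (fp16/bf16). It ignores any padding or paging overhead a real serving stack adds.

```python
def kv_cache_bytes(num_layers, max_len, num_kv_heads, head_dim, bytes_per_elem=2):
    """Hypothetical helper: K/V cache size for ONE sequence (2 = one K plus one V)."""
    return 2 * num_layers * max_len * num_kv_heads * head_dim * bytes_per_elem


GiB = 1024 ** 3
# Assumed LLaMA-2 70B-like shapes: 80 layers, head_dim 128, 2-byte elements.
mha = kv_cache_bytes(num_layers=80, max_len=4096, num_kv_heads=64, head_dim=128)
gqa = kv_cache_bytes(num_layers=80, max_len=4096, num_kv_heads=8, head_dim=128)
print(f"MHA: {mha / GiB:.2f} GiB, GQA: {gqa / GiB:.2f} GiB, ratio: {mha // gqa}x")
# MHA: 10.00 GiB, GQA: 1.25 GiB, ratio: 8x
```

Multiply by the batch size for the whole cache. The 8x ratio holds at any batch size because only num_kv_heads changes.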

Common pitfalls

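  • Repeat axis wrong. Repeat along the heads axis (axis 0 after the transpose to (Hkv, T, Dh)). axis=1 repeats timesteps instead and produces (Hkv, T * group_size, Dh), which no longer lines up with Q’s heads.
  • jnp.tile instead of jnp.repeat. Tile repeats the whole head sequence ([A, B, A, B]), so Q heads get paired with the wrong K/V heads. Repeat keeps each K/V head’s copies contiguous ([A, A, B, B]). Tiled along the head axis, the result has the right shape, so nothing errors.
  • num_heads not divisible by num_kv_heads. Then the groups don’t partition the Q heads cleanly. Assert it.
  • Repeat factor swapped. H // Hkv is group_size. Hkv // H is 0 whenever Hkv < H, so jnp.repeat silently returns an empty (0, T, Dh) array. The failure only surfaces later, as a shape error in the score matmul.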
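The first and last pitfalls show up directly in the shapes. This sketch uses toy sizes (H=8, Hkv=2, T=5, Dh=4, chosen arbitrarily) and a K tensor already transposed to (Hkv, T, Dh):

```python
import jax.numpy as jnp

H, Hkv, T, Dh = 8, 2, 5, 4
k = jnp.ones((Hkv, T, Dh))                      # K after transpose: (Hkv, T, Dh)

good = jnp.repeat(k, H // Hkv, axis=0)          # heads axis: (8, 5, 4)
wrong_axis = jnp.repeat(k, H // Hkv, axis=1)    # time axis:  (2, 20, 4)
swapped = jnp.repeat(k, Hkv // H, axis=0)       # factor 0:   (0, 5, 4), no error here

print(good.shape, wrong_axis.shape, swapped.shape)
```

Only good matches Q's (H, T, Dh) layout. Asserting k.shape[0] == H right after the repeat catches both mistakes immediately.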

Problem

Write mha_gqa(seed, x, num_heads, num_kv_heads, d_model):

  1. Define GQAMHA(nnx.Module) with Q/K/V projections of differing widths (Q is full, K/V are num_kv_heads * head_dim), plus out_proj of full width.
  2. __call__: project, reshape Q to (H, T, Dh) and K/V to (Hkv, T, Dh), repeat K/V along axis 0 by H // Hkv to match Q’s head count, SDPA, concat back, out_proj.
  3. Cast seed and the dimension args from float to int. Return the output flattened to 1-D.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, num_kv_heads, d_model: ints (passed as floats). num_heads % num_kv_heads == 0 and d_model % num_heads == 0.

Output: 1-D flattened.
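The float-to-int casting and the divisibility constraints from the inputs list can be handled up front. parse_gqa_args is a hypothetical helper and not part of the required interface. It only shows the casting and the checks, not the model itself.

```python
def parse_gqa_args(seed, num_heads, num_kv_heads, d_model):
    """Hypothetical helper: cast float-encoded args to int and validate them."""
    seed, num_heads, num_kv_heads, d_model = (
        int(a) for a in (seed, num_heads, num_kv_heads, d_model)
    )
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return seed, num_heads, num_kv_heads, d_model


print(parse_gqa_args(0.0, 8.0, 2.0, 64.0))  # (0, 8, 2, 64)
```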
