Grouped-Query Attention (GQA)

Why this matters

Standard MHA: each head has its own Q, K, and V projections. With H=64 heads in a 70B-parameter model, the KV cache (the keys and values kept for every past token, in every layer) grows enormous; at long contexts it can dominate GPU memory.

Multi-Query Attention (MQA, Shazeer 2019) shares a single K/V head across all H query heads. The cache shrinks H×, but model quality can drop.

Grouped-Query Attention (GQA, Ainslie et al. 2023) is the compromise: it keeps K key/value heads (1 < K < H), each shared by H/K query heads. LLaMA-2 70B uses H=64, K=8, an 8× cache reduction with little quality loss. Mistral 7B and the LLaMA-3 models use GQA as well.

The shapes

Standard MHA (num_heads=H):

  • Q: (T, H, d)
  • K: (T, H, d)
  • V: (T, H, d)
  • KV cache: 2 * T * H * d floats per layer.

GQA (num_heads=H, num_kv_heads=K, K < H):

  • Q: (T, H, d) — full
  • K: (T, K, d) — only K heads
  • V: (T, K, d) — only K heads
  • KV cache: 2 * T * K * d floats per layer, H/K× smaller.
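
The per-layer formulas scale with the number of layers and the bytes per stored value. A rough sketch of the whole-model cache for one sequence, assuming LLaMA-2 70B's published shape (80 layers, head_dim d = 128) and a 2-byte fp16/bf16 cache; the helper name kv_cache_bytes is illustrative:

```python
def kv_cache_bytes(T, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Size of the K and V cache for one sequence of T tokens, all layers."""
    # Factor 2: one cached tensor for K and one for V.
    return 2 * T * num_layers * num_kv_heads * head_dim * bytes_per_value

# Assumed LLaMA-2 70B-style config: 80 layers, head_dim 128, 2-byte values.
T, L, d = 4096, 80, 128
mha = kv_cache_bytes(T, L, num_kv_heads=64, head_dim=d)  # MHA: K = H = 64
gqa = kv_cache_bytes(T, L, num_kv_heads=8, head_dim=d)   # GQA: K = 8
print(mha / 2**30, gqa / 2**30, mha // gqa)  # 10.0 1.25 8  (GiB, GiB, ratio)
```

The ratio is exactly H/K, independent of T, the layer count, or precision.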

Each K/V head is shared by H/K query heads. Concretely, you compute K and V with K heads, then repeat them H/K times along the head axis so they line up with Q's H heads.
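
A shape-only sketch of that repeat step, using NumPy as a stand-in for jax.numpy (np.repeat and jnp.repeat share the same semantics); the toy sizes are arbitrary:

```python
import numpy as np

T, H, K, d = 5, 4, 2, 3              # toy sizes; H must be divisible by K
rng = np.random.default_rng(0)
q = rng.standard_normal((T, H, d))   # all H query heads
k = rng.standard_normal((T, K, d))   # only K key heads
v = rng.standard_normal((T, K, d))   # only K value heads

# Repeat each K/V head H // K times along the head axis (axis=-2).
k_rep = np.repeat(k, H // K, axis=-2)
v_rep = np.repeat(v, H // K, axis=-2)
print(q.shape, k.shape, k_rep.shape)  # (5, 4, 3) (5, 2, 3) (5, 4, 3)

# Query heads 0 and 1 now see KV head 0; heads 2 and 3 see KV head 1.
assert np.array_equal(k_rep[:, 1], k[:, 0])
assert np.array_equal(k_rep[:, 2], k[:, 1])
```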

A custom Flax module

flax.linen.MultiHeadDotProductAttention in this codebase doesn’t take num_kv_heads (only the newer flax.nnx.MultiHeadAttention does). So you’ll build a small custom module:

import jax
import jax.numpy as jnp
import flax.linen as nn


class GQA(nn.Module):
    num_heads: int       # H
    num_kv_heads: int    # K (must divide H)
    qkv_features: int    # D, head_dim = D // H

    @nn.compact
    def __call__(self, x):                      # x: (T, D_in)
        H, K, D = self.num_heads, self.num_kv_heads, self.qkv_features
        assert H % K == 0 and D % H == 0, "need H % K == 0 and D % H == 0"
        d = D // H                              # per-head dim
        in_features = x.shape[-1]
        # Q keeps all H heads; K and V are projected to only K heads.
        q = nn.DenseGeneral(features=(H, d), axis=-1, name='q')(x)   # (T, H, d)
        k = nn.DenseGeneral(features=(K, d), axis=-1, name='k')(x)   # (T, K, d)
        v = nn.DenseGeneral(features=(K, d), axis=-1, name='v')(x)   # (T, K, d)

        # Repeat K and V along the head axis to match H query heads.
        repeats = H // K
        k = jnp.repeat(k, repeats, axis=-2)   # (T, H, d)
        v = jnp.repeat(v, repeats, axis=-2)

        # Per-head scaled dot-product attention (no mask; x is assumed 2-D).
        q_h = jnp.moveaxis(q, -2, 0)   # (H, T, d)
        k_h = jnp.moveaxis(k, -2, 0)
        v_h = jnp.moveaxis(v, -2, 0)
        scores = jnp.einsum('hqd,hkd->hqk', q_h, k_h) / jnp.sqrt(d)   # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)     # normalize over key positions
        out = jnp.einsum('hqk,hkd->hqd', weights, v_h)   # (H, T, d)
        out = jnp.moveaxis(out, 0, -2)  # (T, H, d)

        # Final out projection from (T, H, d) back to (T, in_features).
        return nn.DenseGeneral(features=in_features, axis=(-2, -1), name='out')(out)
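
For checking the module's math without Flax, here is a minimal NumPy sketch of the same forward pass. It assumes weights laid out like the DenseGeneral kernels above (Wq: (D_in, H, d), Wk and Wv: (D_in, K, d), Wo: (H, d, D_in)) and omits the bias terms; the names gqa_reference, Wq, Wk, Wv, Wo are illustrative, not part of any library.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_reference(x, Wq, Wk, Wv, Wo):
    """GQA forward pass for a 2-D input x of shape (T, D_in); no biases, no mask."""
    H, K, d = Wq.shape[1], Wk.shape[1], Wq.shape[2]
    q = np.einsum('ti,ihd->thd', x, Wq)        # (T, H, d)
    k = np.einsum('ti,ikd->tkd', x, Wk)        # (T, K, d)
    v = np.einsum('ti,ikd->tkd', x, Wv)        # (T, K, d)
    k = np.repeat(k, H // K, axis=-2)          # (T, H, d)
    v = np.repeat(v, H // K, axis=-2)
    scores = np.einsum('qhd,khd->hqk', q, k) / np.sqrt(d)     # (H, T, T)
    out = np.einsum('hqk,khd->qhd', softmax(scores), v)       # (T, H, d)
    return np.einsum('thd,hdo->to', out, Wo)   # out projection -> (T, D_in)

rng = np.random.default_rng(0)
T, D_in, H, K, D = 6, 8, 4, 2, 8
d = D // H
x = rng.standard_normal((T, D_in))
out = gqa_reference(
    x,
    rng.standard_normal((D_in, H, d)),
    rng.standard_normal((D_in, K, d)),
    rng.standard_normal((D_in, K, d)),
    rng.standard_normal((H, d, D_in)),
)
print(out.shape)  # (6, 8)
```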

Why “repeat” works

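Conceptually, each K/V head is shared by H/K query heads. After jnp.repeat(k, H // K, axis=-2), K's shape matches Q's, so standard per-head scaled dot-product attention (SDPA) proceeds as if the heads were independent. What GQA saves is the K/V projection parameters and, more importantly, the KV cache: only the K distinct heads are computed and stored, and every query head in a group attends over the same keys and values. The repeat itself does materialize a (T, H, d) copy at runtime; optimized kernels can avoid that copy by grouping the query heads instead.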
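
One way to see that the repeat is only a convenience: the same result can be computed by reshaping Q into K groups of H/K heads and letting each group attend over its single K/V head, with no repeated copy. A NumPy sketch under those assumptions (function names are hypothetical; no masking; inputs shaped (T, heads, d)):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_repeat(q, k, v):
    """q: (T, H, d); k, v: (T, K, d). Materializes the repeated K/V."""
    H, K, d = q.shape[1], k.shape[1], q.shape[2]
    k = np.repeat(k, H // K, axis=1)
    v = np.repeat(v, H // K, axis=1)
    w = softmax(np.einsum('qhd,khd->hqk', q, k) / np.sqrt(d))
    return np.einsum('hqk,khd->qhd', w, v)

def gqa_grouped(q, k, v):
    """Same result without copying K/V: query head h -> group h // (H // K)."""
    T, H, d = q.shape
    K = k.shape[1]
    qg = q.reshape(T, K, H // K, d)                    # (T, groups, heads per group, d)
    w = softmax(np.einsum('qgrd,kgd->grqk', qg, k) / np.sqrt(d))
    out = np.einsum('grqk,kgd->qgrd', w, v)            # (T, K, H // K, d)
    return out.reshape(T, H, d)

rng = np.random.default_rng(0)
T, H, K, d = 6, 8, 2, 4
q = rng.standard_normal((T, H, d))
k = rng.standard_normal((T, K, d))
v = rng.standard_normal((T, K, d))
print(np.allclose(gqa_repeat(q, k, v), gqa_grouped(q, k, v)))  # True
```

The grouped form relies on the same contiguous head ordering that jnp.repeat produces.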

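jnp.repeat(x, n, axis) differs from jnp.tile: repeat places n consecutive copies of each slice along the axis (heads 0, 0, 1, 1), while tile copies the whole array as a block (heads 0, 1, 0, 1). For GQA we want repeat: head 0 of K is shared by query heads 0..H/K-1, head 1 by the next H/K query heads, and so on.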
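
A small NumPy sketch that labels each K/V head by its index makes the difference visible (np.repeat and np.tile behave like their jax.numpy counterparts):

```python
import numpy as np

K, H = 2, 4
kv_heads = np.arange(K).reshape(1, K, 1)   # shape (T=1, K, d=1); value = head index

rep = np.repeat(kv_heads, H // K, axis=-2)   # each head copied in place
til = np.tile(kv_heads, (1, H // K, 1))      # whole head axis copied as a block
print(rep[0, :, 0])   # [0 0 1 1] -> query heads 0,1 use KV head 0
print(til[0, :, 0])   # [0 1 0 1] -> query heads 0,2 use KV head 0
```

Only the repeat ordering matches the contiguous grouping described above.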

Constraints

  • H % K == 0 — H must be divisible by K.
  • D % H == 0 — total Q dim must split evenly.

Common LLM choices: H=32, K=8 (LLaMA-3 8B) and H=64, K=8 (LLaMA-2 70B).
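
A hypothetical helper that checks both constraints and lists which K/V head each query head reads under the repeat ordering. K = 1 reduces to MQA and K = H to standard MHA:

```python
def gqa_head_map(num_heads, num_kv_heads, qkv_features):
    """Validate a GQA config; return the KV head index used by each query head."""
    if num_heads % num_kv_heads != 0:
        raise ValueError(f"H={num_heads} is not divisible by K={num_kv_heads}")
    if qkv_features % num_heads != 0:
        raise ValueError(f"D={qkv_features} is not divisible by H={num_heads}")
    group = num_heads // num_kv_heads
    # jnp.repeat ordering: query head h reads KV head h // group.
    return [h // group for h in range(num_heads)]

print(gqa_head_map(4, 2, 16))          # [0, 0, 1, 1]
print(gqa_head_map(32, 8, 4096)[:8])   # [0, 0, 0, 0, 1, 1, 1, 1]
```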

Common pitfalls

  • jnp.tile instead of jnp.repeat: tile copies the whole block of heads, repeat copies each head in place. The wrong order pairs query heads with the wrong K/V heads.
  • Wrong axis for repeat: axis=-2 is the head axis after DenseGeneral outputs (T, K, d).
  • Forgetting the final out projection — without it, the output is per-head, dim (T, H, d), not the model’s expected (T, D_in).

Problem

Implement mha_gqa(seed, x, num_heads, num_kv_heads, qkv_features) using a custom GQA Module as outlined above:

  1. Build the module with the three configs.
  2. Init with PRNGKey(seed) on x.
  3. Apply on x, return out.reshape(-1).

All test cases use H=4, K=2, so each K/V head is repeated H/K = 2 times.

Inputs:

  • seed: int.
  • x: 2-D (T, D_in).
  • num_heads: H.
  • num_kv_heads: K, divides H.
  • qkv_features: D, divisible by H.

Output: 1-D, the flattened (T, D_in) GQA output.
