Multi-Query Attention (MQA)

Why this matters

Multi-Query Attention (Shazeer 2019) is the most aggressive KV-cache compression: ALL H query heads share a SINGLE K/V head. The KV cache shrinks by a factor of H, so a 32-head model needs 1/32 the cache memory of standard MHA.

MQA is used by PaLM and the original Falcon models, and it is the limit case of GQA where num_kv_heads = 1.

Trade-off: with only one K/V head, quality is noticeably worse than MHA at the same capacity. GQA (pos 26) was invented as the middle ground, and most modern LLMs chose GQA over pure MQA.

What changes from GQA

Compared to GQA: same code, but num_kv_heads = 1.

  • Q: (T, H, d) — H heads.
  • K: (T, 1, d) — ONE head.
  • V: (T, 1, d) — ONE head.
  • Repeat K and V H times: jnp.repeat(k, H, axis=-2) → (T, H, d).
  • SDPA per head as usual.
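
The shape bookkeeping in the list above can be sketched in plain JAX. This is a minimal illustration, not the reference solution: the function name mqa_attention, the einsum formulation, and the absence of a mask are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def mqa_attention(q, k, v):
    """q: (T, H, d); k, v: (T, 1, d). Returns (T, H, d). No mask (assumption)."""
    H, d = q.shape[-2], q.shape[-1]
    k = jnp.repeat(k, H, axis=-2)                              # (T, 1, d) -> (T, H, d)
    v = jnp.repeat(v, H, axis=-2)                              # (T, 1, d) -> (T, H, d)
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(d)    # (H, T, T)
    weights = jax.nn.softmax(scores, axis=-1)                  # softmax over keys
    return jnp.einsum("hqk,khd->qhd", weights, v)              # (T, H, d)


T, H, d = 6, 4, 8
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (T, H, d))
k = jax.random.normal(kk, (T, 1, d))   # ONE K head
v = jax.random.normal(kv, (T, 1, d))   # ONE V head
out = mqa_attention(q, k, v)
print(out.shape)  # (6, 4, 8)
```

Every query head attends over the same keys and values; only the queries differ per head.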

KV-cache impact at scale

For a 70B-param model with H=64, head_dim=128:

  • MHA cache per layer per token: 2 * 64 * 128 = 16,384 floats.
  • MQA cache: 2 * 1 * 128 = 256 floats (64× smaller).

At a 100k-token context across 80 layers, the MHA cache holds about 1.3 × 10^11 values (roughly 262 GB at 16 bits per value), while MQA cuts it to about 2 × 10^9 values (roughly 4 GB). That’s why MQA was invented: to make autoregressive serving feasible at scale.
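
The cache arithmetic can be checked with a few lines of Python. The helper name kv_cache_values is made up for this sketch, and 2 bytes per value (16-bit storage) is an assumption; the other numbers come from the example above.

```python
def kv_cache_values(num_kv_heads, head_dim, num_layers, num_tokens):
    # The factor 2 accounts for storing both K and V.
    return 2 * num_kv_heads * head_dim * num_layers * num_tokens


H, head_dim, layers, tokens = 64, 128, 80, 100_000
bytes_per_value = 2  # assumption: fp16 / bf16 cache

mha = kv_cache_values(H, head_dim, layers, tokens)
mqa = kv_cache_values(1, head_dim, layers, tokens)

print(f"MHA: {mha * bytes_per_value / 1e9:.1f} GB")  # MHA: 262.1 GB
print(f"MQA: {mqa * bytes_per_value / 1e9:.1f} GB")  # MQA: 4.1 GB
print(mha // mqa)                                    # 64
```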

Subtle point: parameter count

K and V projections shrink dramatically:

  • MHA K-proj params: D_in * H * head_dim.
  • MQA K-proj params: D_in * 1 * head_dim, i.e. H× fewer params (the same holds for the V projection).

The Q and out projections are unchanged. Total parameter count drops a little, but the runtime impact (smaller KV cache, less memory bandwidth at decode) is the bigger win.
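
A rough per-layer count makes the split concrete. This sketch ignores biases and assumes D_in = H * head_dim = 8192, which the text does not state; attn_proj_params is a hypothetical helper.

```python
def attn_proj_params(d_in, num_heads, num_kv_heads, head_dim):
    """Weights in the Q, K, V and out projections of one layer (biases ignored)."""
    return {
        "q": d_in * num_heads * head_dim,
        "k": d_in * num_kv_heads * head_dim,
        "v": d_in * num_kv_heads * head_dim,
        "out": num_heads * head_dim * d_in,
    }


H, head_dim = 64, 128
d_in = H * head_dim  # assumption: D_in = 8192

mha = attn_proj_params(d_in, H, H, head_dim)
mqa = attn_proj_params(d_in, H, 1, head_dim)

print(mha["k"] // mqa["k"])                    # 64
print(mha["q"] == mqa["q"])                    # True
print(sum(mqa.values()) / sum(mha.values()))   # 0.5078125
```

Under these assumptions the attention projections roughly halve, but the MLP and embedding weights are untouched, so the whole-model reduction is smaller.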

Custom Flax module

Same shape as the GQA module from pos 26, hardcoded to num_kv_heads = 1:

import jax.numpy as jnp
import flax.linen as nn

class MQA(nn.Module):
    num_heads: int
    qkv_features: int

    @nn.compact
    def __call__(self, x):                     # x: (T, D_in)
        H = self.num_heads
        head_dim = self.qkv_features // H
        in_features = x.shape[-1]              # D_in, needed by the out projection
        q = nn.DenseGeneral(features=(H, head_dim), axis=-1, name='q')(x)  # (T, H, d)
        k = nn.DenseGeneral(features=(1, head_dim), axis=-1, name='k')(x)  # (T, 1, d)
        v = nn.DenseGeneral(features=(1, head_dim), axis=-1, name='v')(x)  # (T, 1, d)
        k = jnp.repeat(k, H, axis=-2)   # broadcast 1 -> H heads: (T, H, d)
        v = jnp.repeat(v, H, axis=-2)
        # ... usual per-head SDPA + out projection

Common pitfalls

  • Forgetting the repeat: without it, K has shape (T, 1, d) and can’t pair with Q’s H heads in the einsum. (Broadcasting over the size-1 head axis can work in some formulations, but the moveaxis pattern requires an explicit shape match.)
  • jnp.tile vs jnp.repeat: with a single K/V head they happen to give the same result, but with more than one K/V head they order the copies differently, so using repeat keeps the GQA template consistent.
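
Both pitfalls can be checked on small arrays. The shapes and variable names below are illustrative only; the two-head case stands in for GQA.

```python
import jax.numpy as jnp

T, H, d = 3, 4, 2

# One K/V head (MQA): repeat and tile agree.
k1 = jnp.arange(T * d, dtype=jnp.float32).reshape(T, 1, d)
print(jnp.array_equal(jnp.repeat(k1, H, axis=-2), jnp.tile(k1, (1, H, 1))))  # True

# Two K/V heads (GQA-style): the copies are ordered differently.
k2 = jnp.arange(T * 2 * d, dtype=jnp.float32).reshape(T, 2, d)
rep = jnp.repeat(k2, H // 2, axis=-2)   # head order 0, 0, 1, 1
til = jnp.tile(k2, (1, H // 2, 1))      # head order 0, 1, 0, 1
print(jnp.array_equal(rep, til))        # False

# An einsum that drops the size-1 head axis gives the same scores as the repeat.
q = jnp.ones((T, H, d))
s_rep = jnp.einsum("qhd,khd->hqk", q, jnp.repeat(k1, H, axis=-2))
s_bc = jnp.einsum("qhd,kd->hqk", q, k1[:, 0, :])
print(jnp.allclose(s_rep, s_bc))        # True
```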

Problem

Implement mha_mqa(seed, x, num_heads, qkv_features) using a custom MQA Module (similar to GQA but hard-coded to one K/V head):

  1. Project Q with H heads, K and V with 1 head.
  2. Apply jnp.repeat(k, H, axis=-2), and the same for v, to broadcast to H heads.
  3. Per-head SDPA as in GQA.
  4. Final out projection back to D_in.

Inputs:

  • seed: int.
  • x: 2-D (T, D_in).
  • num_heads: H.
  • qkv_features: D, divisible by H.

Output: 1-D, the flattened MQA output.
