
NNX Multi-Query Attention (MQA)

Why this matters

Multi-Query Attention (MQA) is the extreme case of GQA: num_kv_heads = 1. A single K head and a single V head are shared across every Q head, giving the maximum KV-cache compression: a factor of num_heads relative to standard multi-head attention.

PaLM (Google) used MQA, as do Falcon and StarCoder. The trade-off: quality drops slightly more than with GQA, since all Q heads now share one set of keys and values, but the KV cache shrinks by a factor of num_heads, which sharply reduces inference memory footprint and the memory traffic that dominates per-token decoding latency.

Difference from GQA

GQA with num_kv_heads = 1 is MQA, so you could reuse the GQA implementation and simply pass num_kv_heads=1. We write a dedicated path here for clarity and because the broadcast is even simpler: there are no groups to repeat across, just one K/V head broadcast across all of Q.
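One way to convince yourself of the equivalence is to run a generic GQA attention with a single KV head and compare it against the broadcast-based MQA path. This is a minimal JAX sketch, assuming unbatched, unmasked inputs; gqa_sdpa and mqa_sdpa are hypothetical helper names, not library functions:

```python
import jax
import jax.numpy as jnp


def gqa_sdpa(q, k, v):
    """Unmasked grouped-query SDPA (hypothetical helper).

    q: (H, T, Dh); k, v: (Hkv, T, Dh), with H divisible by Hkv.
    """
    group = q.shape[0] // k.shape[0]
    # Repeat each KV head `group` times so Q head h reads KV head h // group.
    k = jnp.repeat(k, group, axis=0)
    v = jnp.repeat(v, group, axis=0)
    scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(q.shape[-1])
    return jnp.matmul(jax.nn.softmax(scores, axis=-1), v)


def mqa_sdpa(q, k, v):
    """Unmasked MQA SDPA (hypothetical helper): k and v are (T, Dh), no head axis."""
    H, T, Dh = q.shape
    k = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
    v = jnp.broadcast_to(v[None, :, :], (H, T, Dh))
    scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
    return jnp.matmul(jax.nn.softmax(scores, axis=-1), v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
H, T, Dh = 8, 5, 16
q = jax.random.normal(kq, (H, T, Dh))
k = jax.random.normal(kk, (T, Dh))
v = jax.random.normal(kv, (T, Dh))

# GQA with one KV head: k[None] and v[None] have shape (1, T, Dh).
out_gqa = gqa_sdpa(q, k[None], v[None])
out_mqa = mqa_sdpa(q, k, v)
print(out_gqa.shape, jnp.allclose(out_gqa, out_mqa, atol=1e-5))
```

With group = H, jnp.repeat turns the single KV head into H identical copies, which is exactly what the MQA broadcast produces.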

Projections:

  • q_proj: d_model -> num_heads * head_dim (every Q head)
  • k_proj: d_model -> head_dim (ONE K head, no head axis in output)
  • v_proj: d_model -> head_dim (ONE V head)
  • out_proj: num_heads * head_dim -> d_model

Broadcasting one K/V head

K and V come out as (T, head_dim) — no head axis at all. To line up with q of shape (num_heads, T, head_dim), broadcast:

k = jnp.broadcast_to(k[None, :, :], (num_heads, T, head_dim))
v = jnp.broadcast_to(v[None, :, :], (num_heads, T, head_dim))

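arr[None, :, :] adds a leading singleton axis: (1, T, Dh). jnp.broadcast_to then expands it to (H, T, Dh). Unlike NumPy, where broadcast_to returns a zero-copy view, JAX has no views: run eagerly, this produces a full (H, T, Dh) array, although under jax.jit XLA may avoid materializing it. Either way, SDPA then proceeds normally.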
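A tiny check of the broadcast: after jnp.broadcast_to, every head slice is the same (T, Dh) key matrix. The sizes are arbitrary illustrative values:

```python
import jax.numpy as jnp

H, T, Dh = 4, 3, 2
k = jnp.arange(T * Dh, dtype=jnp.float32).reshape(T, Dh)  # (T, Dh), no head axis

k_heads = jnp.broadcast_to(k[None, :, :], (H, T, Dh))      # (1, T, Dh) -> (H, T, Dh)
print(k[None, :, :].shape, k_heads.shape)  # (1, 3, 2) (4, 3, 2)

# Every Q head will see exactly the same keys.
same = all(bool(jnp.array_equal(k_heads[h], k)) for h in range(H))
print(same)  # True
```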

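Equivalent: jnp.repeat(k[None], num_heads, axis=0), which produces the same values but spells out the copy explicitly. broadcast_to states the intent (one K head, shared by every Q head) more directly, but in JAX neither call is guaranteed to be copy-free; to skip the (H, T, Dh) key tensor by construction, contract against the single K head directly, for example with jnp.einsum.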
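To compare the options, the sketch below computes the MQA attention scores three ways under jax.jit: with broadcast_to, with jnp.repeat, and with jnp.einsum contracting directly against the single (T, Dh) key matrix, which never forms an (H, T, Dh) key tensor at the JAX level. The einsum route is an alternative formulation, not what the worked sketch uses, and the function names are hypothetical. Whether XLA actually materializes the broadcast or repeat copy is a compiler decision, so the sketch only checks that the three agree numerically:

```python
import jax
import jax.numpy as jnp


def scores_broadcast(q, k):
    H, T, Dh = q.shape
    k_heads = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
    return jnp.matmul(q, k_heads.transpose(0, 2, 1))


def scores_repeat(q, k):
    k_heads = jnp.repeat(k[None], q.shape[0], axis=0)  # spells out an (H, T, Dh) copy
    return jnp.matmul(q, k_heads.transpose(0, 2, 1))


def scores_einsum(q, k):
    # h: head, t: query position, s: key position, d: head_dim.
    return jnp.einsum("htd,sd->hts", q, k)


q = jax.random.normal(jax.random.PRNGKey(0), (8, 6, 16))  # (H, T, Dh)
k = jax.random.normal(jax.random.PRNGKey(1), (6, 16))     # (T, Dh)
a = jax.jit(scores_broadcast)(q, k)
b = jax.jit(scores_repeat)(q, k)
c = jax.jit(scores_einsum)(q, k)
print(a.shape)  # (8, 6, 6)
print(jnp.allclose(a, b, rtol=1e-3, atol=1e-3), jnp.allclose(a, c, rtol=1e-3, atol=1e-3))
```

The einsum form also makes the MQA cache layout explicit: the kernel only ever reads the one (T, Dh) key matrix.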

Worked sketch

import jax
import jax.numpy as jnp
from flax import nnx

class MQAMHA(nnx.Module):
    def __init__(self, d_model, num_heads, *, rngs: nnx.Rngs):
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Q is full width; K and V are a single head each; out_proj maps back to d_model.
        self.q_proj = nnx.Linear(d_model, num_heads * self.head_dim, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, self.head_dim, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, self.head_dim, rngs=rngs)
        self.out_proj = nnx.Linear(num_heads * self.head_dim, d_model, rngs=rngs)

    def __call__(self, x):
        # x: (T, d_model), unbatched, no mask.
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)  # (H, T, Dh)
        k = self.k_proj(x)                       # (T, Dh): one shared K head
        v = self.v_proj(x)                       # (T, Dh): one shared V head
        # Add a leading head axis and broadcast it to every Q head.
        k = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
        v = jnp.broadcast_to(v[None, :, :], (H, T, Dh))
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)                              # (H, T, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)       # (T, d_model)
        return self.out_proj(concat)
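To sanity-check the vectorized broadcast path, it helps to compare it against an explicit per-head loop in which every head reuses the same K and V. This is a functional sketch with plain weight matrices in (in, out) layout, matching how nnx.Linear stores its kernel; biases are omitted for brevity, and mqa_vectorized and mqa_loop are hypothetical names:

```python
import jax
import jax.numpy as jnp


def mqa_vectorized(x, wq, wk, wv, wo, num_heads):
    T = x.shape[0]
    H, Dh = num_heads, wk.shape[1]
    q = (x @ wq).reshape(T, H, Dh).transpose(1, 0, 2)  # (H, T, Dh)
    k = jnp.broadcast_to((x @ wk)[None], (H, T, Dh))    # shared K head
    v = jnp.broadcast_to((x @ wv)[None], (H, T, Dh))    # shared V head
    weights = jax.nn.softmax(q @ k.transpose(0, 2, 1) / jnp.sqrt(Dh), axis=-1)
    return (weights @ v).transpose(1, 0, 2).reshape(T, H * Dh) @ wo


def mqa_loop(x, wq, wk, wv, wo, num_heads):
    Dh = wk.shape[1]
    k, v = x @ wk, x @ wv                    # (T, Dh), used by every head
    heads = []
    for h in range(num_heads):
        q_h = x @ wq[:, h * Dh:(h + 1) * Dh]  # head h's slice of the Q columns
        weights = jax.nn.softmax(q_h @ k.T / jnp.sqrt(Dh), axis=-1)
        heads.append(weights @ v)            # (T, Dh)
    return jnp.concatenate(heads, axis=-1) @ wo


T, d_model, H = 5, 32, 4
Dh = d_model // H
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (T, d_model))
wq = 0.1 * jax.random.normal(keys[1], (d_model, H * Dh))
wk = 0.1 * jax.random.normal(keys[2], (d_model, Dh))
wv = 0.1 * jax.random.normal(keys[3], (d_model, Dh))
wo = 0.1 * jax.random.normal(keys[4], (H * Dh, d_model))

fast = mqa_vectorized(x, wq, wk, wv, wo, H)
slow = mqa_loop(x, wq, wk, wv, wo, H)
print(fast.shape, jnp.allclose(fast, slow, atol=1e-4))
```

The loop makes the column bookkeeping explicit: reshape(T, H, Dh) assigns columns h * Dh to (h + 1) * Dh of the Q projection to head h, and concatenating the heads along the last axis puts them back in the same order.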

How much does this save?

With num_heads=64, head_dim=128, max_len=4096, and fp16 storage (2 bytes per element; the leading factor of 2 counts K and V):

  • Full MHA cache: 2 * 4096 * 64 * 128 * 2 bytes = 134,217,728 bytes ≈ 134 MB (128 MiB) per layer per sequence.
  • MQA cache: 2 * 4096 * 1 * 128 * 2 bytes = 2,097,152 bytes ≈ 2 MB (2 MiB) per layer per sequence.

A 64x reduction. For a 32-layer model that is 4 GiB -> 64 MiB per sequence; at batch 16, 64 GiB -> 1 GiB.
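The arithmetic above is easy to script. A minimal sketch, assuming fp16/bf16 storage (2 bytes per element) and ignoring padding and allocator overhead; kv_cache_bytes is a hypothetical helper, and the 8-KV-head GQA line is included only for comparison:

```python
MiB, GiB = 2**20, 2**30


def kv_cache_bytes(seq_len, num_kv_heads, head_dim, num_layers=1, batch=1, bytes_per_elem=2):
    # Leading 2: one K tensor and one V tensor are cached per layer.
    return 2 * seq_len * num_kv_heads * head_dim * bytes_per_elem * num_layers * batch


mha = kv_cache_bytes(4096, num_kv_heads=64, head_dim=128)
gqa8 = kv_cache_bytes(4096, num_kv_heads=8, head_dim=128)
mqa = kv_cache_bytes(4096, num_kv_heads=1, head_dim=128)
print(mha / MiB, gqa8 / MiB, mqa / MiB)  # 128.0 16.0 2.0
print(mha // mqa)                        # 64

# 32 layers, batch 16
print(kv_cache_bytes(4096, 64, 128, num_layers=32, batch=16) / GiB)  # 64.0
print(kv_cache_bytes(4096, 1, 128, num_layers=32, batch=16) / GiB)   # 1.0
```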

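The price is quality. Each Q head can still attend differently, but all of them score against the same keys and read from the same values, so heads have less room to specialize and the model tends to "blur" attention patterns that separate K/V heads would keep distinct. GQA strikes a middle ground (4 or 8 KV heads) that keeps most of the speedup with less quality loss.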

Common pitfalls

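  • Broadcasting into the wrong axis order. The singleton axis itself is optional: broadcasting prepends missing leading axes, so jnp.broadcast_to(k, (H, T, Dh)) with k.shape == (T, Dh) also works, and k[None, :, :] just makes the intent explicit. The real mistake is a target such as (T, H, Dh): trailing axes are matched first, so it raises a shape-mismatch error when T != H and silently pairs the wrong axes when T == H.
  • Wrong input width for out_proj. Q is full width (num_heads * head_dim), so the concatenated heads, and therefore out_proj's input, are full width, not head_dim.
  • Reshaping K into heads. The K projection maps d_model -> head_dim directly; there is no Hkv * Dh to split.
  • Using jnp.repeat. It works and gives the same values, but it spells out an (H, T, Dh) copy, while broadcast_to states the intent directly. (Neither is a view in JAX; under jit, XLA decides whether the broadcast array is actually materialized.)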
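The broadcasting shape rules are easy to check directly. jnp.broadcast_to follows NumPy's broadcasting rules, which prepend missing leading axes and then match trailing axes, so the explicit k[None, :, :] and the bare k give the same result, while a target with the axes in the wrong order is rejected. Sizes are illustrative, with T != H on purpose:

```python
import jax.numpy as jnp

H, T, Dh = 4, 3, 2
k = jnp.arange(T * Dh, dtype=jnp.float32).reshape(T, Dh)  # (T, Dh)

# Missing leading axes are prepended, so k[None] is optional here.
explicit = jnp.broadcast_to(k[None, :, :], (H, T, Dh))
implicit = jnp.broadcast_to(k, (H, T, Dh))
print(bool(jnp.array_equal(explicit, implicit)))  # True

# Wrong axis order: trailing axes (T, Dh) = (3, 2) are matched against (H, Dh) = (4, 2).
raised = False
try:
    jnp.broadcast_to(k, (T, H, Dh))
except (ValueError, TypeError) as err:
    raised = True
    print("shape error:", err)
print(raised)  # True
```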

Problem

Write mha_mqa(seed, x, num_heads, d_model):

  1. Define MQAMHA(nnx.Module): Q full width, K and V of width head_dim (single head), out_proj full width.
  2. __call__: project, reshape Q to (H, T, Dh), broadcast K/V from (T, Dh) to (H, T, Dh), SDPA, concat, out_proj.
  3. Cast num_heads, d_model to int. Build nnx.Rngs(int(seed)). Return flattened.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model: ints (passed as floats).

Output: 1-D flattened.

