NNX RoPE Attention

Why this matters

Rotary Positional Embedding (RoPE), introduced by Su et al. (2021), has become the de facto positional encoding for modern LLMs, including LLaMA, LLaMA-2, LLaMA-3, Qwen, Mistral, Falcon, and GPT-NeoX. Instead of adding a position embedding to the inputs (absolute) or biasing the attention scores (ALiBi), RoPE rotates the Q and K vectors in 2-D feature pairs by an angle that depends on position.

The neat property: the dot product of RoPE(q_i) and RoPE(k_j) is a function of q_i, k_j, and (i - j) only, i.e. of RELATIVE position. Q and K each see only their own absolute position; their dot product encodes the relative offset implicitly. This relative structure is a large part of why RoPE is the usual choice for long-context fine-tuning (“rope scaling”), where positions or rotation frequencies are rescaled to cover longer sequences.

Note: RoPE is applied to Q and K, NOT to V. V keeps its content role; rotating Q/K affects only the score geometry.

The math

Pick a base B (10000 in the original RoPE; smaller bases such as 500 make every pair rotate faster, which can suit short contexts). For each dimension pair i ∈ [0, head_dim/2):

θ_i = B^(-2i / head_dim)

Per token position pos:

angle[pos, i] = pos * θ_i              # shape (T, head_dim/2)
cos[pos, i]   = cos(angle[pos, i])
sin[pos, i]   = sin(angle[pos, i])
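
To make the table definitions concrete, here is a minimal JAX sketch that builds them for a toy configuration. The helper name rope_tables and the sizes (T = 4, head_dim = 8) are assumptions chosen for illustration only; with base 10000 and head_dim 8 the frequencies come out as powers of ten, which makes them easy to eyeball.

```python
import jax.numpy as jnp

def rope_tables(T, head_dim, base=10000.0):
    """Build theta plus the (T, head_dim/2) cos and sin tables defined above."""
    i = jnp.arange(head_dim // 2)                  # pair index
    theta = jnp.power(base, -2.0 * i / head_dim)   # per-pair frequency
    pos = jnp.arange(T)                            # token positions 0..T-1
    angle = pos[:, None] * theta[None, :]          # (T, head_dim/2)
    return theta, jnp.cos(angle), jnp.sin(angle)

theta, cos, sin = rope_tables(T=4, head_dim=8)
print(theta)      # pair 0 rotates fastest (theta = 1); later pairs rotate slower
print(cos.shape)  # one row per position, one column per feature pair
```

Row 0 of the tables corresponds to position 0, where every angle is zero, so the rotation there is the identity.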

The rotation

For a feature vector x of shape (..., T, head_dim), treat adjacent feature pairs (x[..., 0], x[..., 1]), (x[..., 2], x[..., 3]), … as 2-D points and rotate each pair:

(x1, x2) -> (x1 cos - x2 sin, x1 sin + x2 cos)

where cos and sin come from the position-and-pair tables above.

import jax.numpy as jnp

def rotate(x, cos, sin):
    # cos and sin must broadcast against (..., T, head_dim/2)
    x1 = x[..., 0::2]                  # even-indexed features
    x2 = x[..., 1::2]                  # odd-indexed features
    rx1 = x1 * cos - x2 * sin
    rx2 = x1 * sin + x2 * cos
    rotated = jnp.stack([rx1, rx2], axis=-1)   # (..., T, head_dim/2, 2)
    return rotated.reshape(x.shape)             # back to (..., T, head_dim)

Apply to both Q and K (same cos/sin), then run SDPA as usual. V is untouched.
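
Two properties are handy as a sanity check on any rotate implementation: a rotation never changes a vector's length, and position 0 (angle 0) leaves the vector unchanged. The sketch below checks both on random data; the toy sizes and the PRNG seed are arbitrary assumptions, and rotate is repeated so the block runs on its own.

```python
import jax
import jax.numpy as jnp

def rotate(x, cos, sin):
    # Adjacent-pair rotation, same as the helper in this section.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rx1 = x1 * cos - x2 * sin
    rx2 = x1 * sin + x2 * cos
    return jnp.stack([rx1, rx2], axis=-1).reshape(x.shape)

T, Dh = 5, 8   # toy sizes, chosen only for this example
theta = jnp.power(10000.0, -2.0 * jnp.arange(Dh // 2) / Dh)
angle = jnp.arange(T)[:, None] * theta[None, :]
cos, sin = jnp.cos(angle), jnp.sin(angle)

x = jax.random.normal(jax.random.PRNGKey(0), (T, Dh))
rx = rotate(x, cos, sin)

# Rotations preserve the per-token norm.
same_norm = jnp.allclose(
    jnp.linalg.norm(rx, axis=-1), jnp.linalg.norm(x, axis=-1), atol=1e-4
)
# Position 0 has angle 0 for every pair, so it is left untouched.
identity_at_0 = jnp.allclose(rx[0], x[0])
print(same_norm, identity_at_0)
```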

Worked sketch

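import jax
import jax.numpy as jnp
from flax import nnx

class RopeMHA(nnx.Module):
    # __init__ (elided): four nnx.Linear projections (q_proj, k_proj, v_proj,
    # out_proj), plus num_heads, head_dim and base stored as plain attributes.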

    def _cos_sin(self, T):
        i = jnp.arange(self.head_dim // 2)
        theta = jnp.power(self.base, -2.0 * i / self.head_dim)
        pos = jnp.arange(T)
        angles = pos[:, None] * theta[None, :]   # (T, head_dim/2)
        return jnp.cos(angles), jnp.sin(angles)

    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        cos, sin = self._cos_sin(T)
        cos = cos[None, :, :]                    # (1, T, Dh/2)
        sin = sin[None, :, :]
        q = rotate(q, cos, sin)
        k = rotate(k, cos, sin)                  # V is NOT rotated
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)

Why it encodes relative position

Take the simplest case: head_dim = 2. Then RoPE(q, pos) is literally a rotation of q by pos * θ in the plane. The dot product:

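RoPE(q, i) . RoPE(k, j)
  = (R(iθ) q) . (R(jθ) k)
  = q . R(iθ)^T R(jθ) k       # move R(iθ) across the dot product
  = q . R(-iθ) R(jθ) k        # for a rotation, R(φ)^T = R(-φ)
  = q . R((j - i)θ) k         # rotations compose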

The score is a pure function of (j - i): nothing about i or j individually leaks through, only their difference. For larger head_dim, each feature pair rotates at its own frequency θ, and the per-pair contributions add up in the dot product to give a richer relative encoding.
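
The argument above can also be checked numerically. In this minimal sketch, the helper rope (a hypothetical single-vector version of the rotation, not part of the problem's API) rotates one vector as if it sat at a given position; shifting both positions by the same amount should leave the score unchanged up to float32 rounding.

```python
import jax
import jax.numpy as jnp

def rope(x, pos, base=10000.0):
    """Rotate one vector x of shape (head_dim,) as if it sat at position pos."""
    Dh = x.shape[-1]
    theta = jnp.power(base, -2.0 * jnp.arange(Dh // 2) / Dh)
    cos, sin = jnp.cos(pos * theta), jnp.sin(pos * theta)
    x1, x2 = x[0::2], x[1::2]                      # adjacent pairs
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(Dh)

kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (8,))
k = jax.random.normal(kk, (8,))

score_a = jnp.dot(rope(q, 3), rope(k, 7))      # offset j - i = 4
score_b = jnp.dot(rope(q, 13), rope(k, 17))    # same offset, both shifted by 10
score_c = jnp.dot(rope(q, 3), rope(k, 8))      # offset 5, for comparison

print(score_a, score_b, score_c)  # the first two should agree up to rounding
```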

Common pitfalls

  • Rotating V. Don’t. RoPE is for Q/K only.
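  • Wrong pair grouping. RoPE as described here pairs adjacent features: (x[0], x[1]), (x[2], x[3]), …. Some implementations split the head into halves instead, pairing x[i] with x[i + Dh/2], i.e. (x[:Dh/2], x[Dh/2:]). That half-split layout is usually called “GPT-NeoX style”; the adjacent-pair (interleaved) layout used here is the one usually called “GPT-J style”. The two agree only up to a permutation of features, so they cannot be mixed. Stick to adjacent pairs here.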
  • head_dim odd. RoPE needs even dim. Assert it.
  • Different cos/sin for Q vs K. They must use the same tables — that’s how the relative-position property works.
  • Forgetting to broadcast over heads. cos, sin shape (T, Dh/2), but q is (H, T, Dh). Add a leading axis: cos[None, :, :].
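
To illustrate the pair-grouping pitfall, the sketch below implements both layouts and checks that they match only after permuting the features (even indices to the first half, odd indices to the second). The function names rotate_adjacent and rotate_halves are labels made up for this example; the sizes and seed are arbitrary.

```python
import jax
import jax.numpy as jnp

def rotate_adjacent(x, cos, sin):
    # Pairs (x[0], x[1]), (x[2], x[3]), ...: the layout used in this section.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)

def rotate_halves(x, cos, sin):
    # Pairs (x[i], x[i + Dh/2]): the half-split layout.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

T, Dh = 4, 8
theta = jnp.power(10000.0, -2.0 * jnp.arange(Dh // 2) / Dh)
angle = jnp.arange(T)[:, None] * theta[None, :]
cos, sin = jnp.cos(angle), jnp.sin(angle)
x = jax.random.normal(jax.random.PRNGKey(0), (T, Dh))

# Permutation that moves even-indexed features to the first half, odd to the second.
perm = jnp.concatenate([jnp.arange(0, Dh, 2), jnp.arange(1, Dh, 2)])

a = rotate_adjacent(x, cos, sin)
b = rotate_halves(x[..., perm], cos, sin)

# Same rotation once the feature order is accounted for.
match_up_to_perm = jnp.allclose(a[..., perm], b, atol=1e-6)
print(match_up_to_perm)
```

Applying rotate_halves directly to x, without the permutation, rotates different feature pairs and so generally gives a different result, which is why weights trained under one layout cannot be used with the other unless they are permuted to match.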

Problem

Write mha_rope(seed, x, num_heads, d_model, base):

  1. Define RopeMHA(nnx.Module): four nnx.Linear projections plus base as a plain attribute.
  2. Helper _cos_sin(T): builds (T, head_dim/2) cosine and sine tables from θ_i = base^(-2i/head_dim) and pos * θ_i.
  3. Helper rotate(x, cos, sin): pairs adjacent features (x[..., 0::2], x[..., 1::2]), applies the 2-D rotation, restacks and reshapes back.
  4. __call__: project Q/K/V, rotate Q and K (NOT V), SDPA, concat, out_proj.
  5. Cast seed, num_heads, and d_model to int; base stays a float. Return the output flattened to 1-D.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model: ints (passed as floats). head_dim even.
  • base: float (e.g. 10000.0).

Output: 1-D flattened.

