NNX RoPE Attention
Why this matters
Rotary Positional Embedding (RoPE), introduced by Su et al. (2021), has become the de-facto positional encoding for modern LLMs: LLaMA, LLaMA-2, LLaMA-3, Qwen, Mistral, Falcon, GPT-NeoX. Instead of adding a position embedding to inputs (absolute) or biasing scores (ALiBi), RoPE rotates the Q and K vectors in 2-D feature pairs by an angle that depends on position.
The neat property: the dot product of RoPE(q_i) and RoPE(k_j) depends only on q_i, k_j, and the offset (i - j), i.e. on RELATIVE position. Q and K each see only their own absolute position; their dot product encodes the relative offset implicitly. This relative structure is also why RoPE is the usual target for long-context fine-tuning (“rope scaling”); plain RoPE on its own tends to degrade beyond the context length it was trained on.
Note: RoPE is applied to Q and K, NOT to V. V keeps its content role; rotating Q/K affects only the score geometry.
The math
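Pick a base B (10000 in the original RoPE). A smaller base such as 500 makes every pair except i = 0 rotate faster, which can suit short contexts; larger bases (LLaMA-3 uses 500000) rotate more slowly and are used for long contexts.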
Per dimension pair i ∈ [0, head_dim/2):
θ_i = B^(-2i / head_dim)
Per token position pos:
angle[pos, i] = pos * θ_i # shape (T, head_dim/2)
cos[pos, i] = cos(angle[pos, i])
sin[pos, i] = sin(angle[pos, i])
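A minimal sketch of the table construction above, assuming JAX. `rope_tables` is a hypothetical helper name, and head_dim = 8 is chosen only so that θ_i is easy to read off (1, 0.1, 0.01, 0.001 for base 10000). It also compares base 500 to show that pair 0 always has θ_0 = 1 while the other pairs rotate faster with a smaller base.

```python
import jax.numpy as jnp

def rope_tables(T, head_dim, base=10000.0):
    """Hypothetical helper: returns theta (head_dim/2,), cos and sin (T, head_dim/2)."""
    assert head_dim % 2 == 0, "RoPE needs an even head_dim"
    i = jnp.arange(head_dim // 2)                  # pair index i
    theta = jnp.power(base, -2.0 * i / head_dim)   # theta_i = B^(-2i / head_dim)
    pos = jnp.arange(T)                            # token positions 0..T-1
    angle = pos[:, None] * theta[None, :]          # (T, head_dim/2)
    return theta, jnp.cos(angle), jnp.sin(angle)

theta, cos, sin = rope_tables(T=16, head_dim=8)
print(cos.shape, sin.shape)   # (16, 4) (16, 4)
print(theta)                  # approximately [1, 0.1, 0.01, 0.001]

# Smaller base: pair 0 keeps theta = 1, every other pair rotates faster.
theta_500, _, _ = rope_tables(T=16, head_dim=8, base=500.0)
print(theta_500 > theta)      # [False  True  True  True]
```

Row 0 of both tables corresponds to position 0, where every angle is zero, so cos is all ones and sin all zeros there.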
The rotation
For a feature vector x of shape (..., T, head_dim), treat
adjacent feature pairs (x[..., 0], x[..., 1]),
(x[..., 2], x[..., 3]), … as 2-D points and rotate each pair:
(x1, x2) -> (x1 cos - x2 sin, x1 sin + x2 cos)
where cos and sin come from the position-and-pair tables above.
import jax.numpy as jnp

def rotate(x, cos, sin):
    # x: (..., T, head_dim); cos, sin: broadcastable to (..., T, head_dim/2)
    x1 = x[..., 0::2]  # even-indexed features
    x2 = x[..., 1::2]  # odd-indexed features
    rx1 = x1 * cos - x2 * sin
    rx2 = x1 * sin + x2 * cos
    rotated = jnp.stack([rx1, rx2], axis=-1)  # (..., T, head_dim/2, 2)
    return rotated.reshape(*x.shape)  # interleave back to (..., T, head_dim)
Apply to both Q and K (same cos/sin), then run SDPA as usual.
V is untouched.
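As a sanity check of the rotation, the sketch below (assuming JAX; `rotate_complex` is a hypothetical name) compares `rotate` with the equivalent complex-number view, in which each pair (x1, x2) is treated as x1 + i·x2 and multiplied by e^(i·angle). It also confirms two properties any rotation must have: vector norms are unchanged, and position 0 (all angles zero) is the identity.

```python
import jax
import jax.numpy as jnp

def rotate(x, cos, sin):
    # Adjacent-pair rotation, same as in the text above.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rx1 = x1 * cos - x2 * sin
    rx2 = x1 * sin + x2 * cos
    return jnp.stack([rx1, rx2], axis=-1).reshape(x.shape)

def rotate_complex(x, angle):
    # Hypothetical alternative: view pair (x1, x2) as the complex number x1 + i*x2
    # and multiply by exp(i * angle), i.e. rotate it by `angle` in the plane.
    z = x[..., 0::2] + 1j * x[..., 1::2]
    rz = z * jnp.exp(1j * angle)
    return jnp.stack([jnp.real(rz), jnp.imag(rz)], axis=-1).reshape(x.shape)

T, head_dim, base = 6, 8, 10000.0
theta = jnp.power(base, -2.0 * jnp.arange(head_dim // 2) / head_dim)
angle = jnp.arange(T)[:, None] * theta[None, :]          # (T, head_dim/2)
x = jax.random.normal(jax.random.PRNGKey(0), (T, head_dim))

a = rotate(x, jnp.cos(angle), jnp.sin(angle))
b = rotate_complex(x, angle)
print(jnp.allclose(a, b, atol=1e-5))   # True: same rotation, two notations
print(jnp.allclose(a[0], x[0]))        # True: position 0 is the identity
norms_in = jnp.linalg.norm(x, axis=-1)
norms_out = jnp.linalg.norm(a, axis=-1)
print(jnp.allclose(norms_in, norms_out, atol=1e-5))  # True: lengths preserved
```

The complex form is compact but needs complex dtypes; the real-valued pair formula is what the rest of this lesson uses.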
Worked sketch
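import jax
import jax.numpy as jnp
from flax import nnx

class RopeMHA(nnx.Module):
    # Constructor signature is an assumption: four nnx.Linear projections + base.
    def __init__(self, d_model, num_heads, base, *, rngs: nnx.Rngs):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert self.head_dim % 2 == 0, "RoPE needs an even head_dim"
        self.base = base  # plain float attribute, not a trainable parameter
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def _cos_sin(self, T):
        i = jnp.arange(self.head_dim // 2)
        theta = jnp.power(self.base, -2.0 * i / self.head_dim)
        pos = jnp.arange(T)
        angles = pos[:, None] * theta[None, :]  # (T, head_dim/2)
        return jnp.cos(angles), jnp.sin(angles)

    def __call__(self, x):
        # x: (T, d_model), no batch axis; rotate() is the helper defined above.
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # Project and split heads: (T, d_model) -> (H, T, Dh).
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        cos, sin = self._cos_sin(T)
        cos = cos[None, :, :]  # (1, T, Dh/2), broadcasts over heads
        sin = sin[None, :, :]
        q = rotate(q, cos, sin)
        k = rotate(k, cos, sin)  # V is NOT rotated
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)  # (H, T, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)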
Why it encodes relative position
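Take the simplest case: head_dim = 2, so there is a single pair with θ = θ_0 = 1. Write R(φ) for the 2-D rotation by angle φ. Then RoPE(q, pos) = R(pos θ) q is literally a rotation of q in the plane, and the dot product is:
RoPE(q, i) . RoPE(k, j)
= (R(iθ) q) . (R(jθ) k)
= q . R(iθ)^T R(jθ) k
= q . R(-iθ) R(jθ) k # R(φ)^T = R(-φ)
= q . R((j - i)θ) k # rotations compose
For fixed q and k, this is a pure function of (j - i): neither i nor j leaks through individually, only their difference. With a larger head_dim, each pair i rotates at its own frequency θ_i and the score is the sum of the per-pair terms, which gives a richer relative encoding.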
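The derivation can be checked numerically for a multi-frequency head. A minimal sketch, assuming JAX, with a hypothetical `rope(x, position)` helper that rotates a single vector using the same adjacent-pair convention as `rotate`: scores for the same offset j - i at different absolute positions agree up to float32 rounding.

```python
import jax
import jax.numpy as jnp

def rope(x, position, base=10000.0):
    # Hypothetical helper: rotate one vector x of shape (head_dim,) to `position`
    # with the adjacent-pair convention used by rotate().
    head_dim = x.shape[-1]
    theta = jnp.power(base, -2.0 * jnp.arange(head_dim // 2) / head_dim)
    angle = position * theta                      # (head_dim/2,)
    cos, sin = jnp.cos(angle), jnp.sin(angle)
    x1, x2 = x[0::2], x[1::2]
    return jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1).reshape(x.shape)

kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (16,))   # one query vector, head_dim = 16
k = jax.random.normal(kk, (16,))   # one key vector

def score(i, j):
    # Query rotated to position i, dotted with key rotated to position j.
    return jnp.dot(rope(q, i), rope(k, j))

# Same offset j - i = 3 at three different absolute positions: the scores agree.
s1, s2, s3 = score(2, 5), score(10, 13), score(40, 43)
print(s1, s2, s3)
# A different offset generally gives a different score.
print(score(2, 6))
```

A zero offset (i = j) cancels the rotation entirely, so score(i, i) reduces to the plain dot product q . k.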
Common pitfalls
- Rotating V. Don’t. RoPE is for Q/K only.
- Wrong pair grouping. This lesson pairs adjacent features: (x[0], x[1]), (x[2], x[3]), …. Some implementations split the head into halves instead, pairing x[i] with x[i + Dh/2], i.e. (x[:Dh/2], x[Dh/2:]). That is a different convention, usually called “GPT-NeoX style” (the adjacent, interleaved one is sometimes called “GPT-J style”). Stick to adjacent pairs here.
- Odd head_dim. RoPE needs an even head_dim. Assert it.
- Different cos/sin for Q vs K. They must use the same tables; that is how the relative-position property works.
- Forgetting to broadcast over heads. cos and sin have shape (T, Dh/2), but q is (H, T, Dh). Add a leading axis: cos[None, :, :].
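To see why the pair grouping matters, here is a sketch (assuming JAX; `angles`, `rope_adjacent` and `rope_half` are hypothetical helpers) that applies both conventions to the same input. They disagree, but they agree once the features are permuted so that x[i] and x[i + Dh/2] sit next to each other. The two conventions are therefore equivalent only up to a fixed permutation of the Q/K features, so weights trained under one cannot be reused with the other without permuting them.

```python
import jax
import jax.numpy as jnp

def angles(T, head_dim, base=10000.0):
    theta = jnp.power(base, -2.0 * jnp.arange(head_dim // 2) / head_dim)
    return jnp.arange(T)[:, None] * theta[None, :]        # (T, head_dim/2)

def rope_adjacent(x, ang):
    # Pairs (x[2i], x[2i+1]): the convention used in this lesson.
    c, s = jnp.cos(ang), jnp.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return jnp.stack([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).reshape(x.shape)

def rope_half(x, ang):
    # Pairs (x[i], x[i + head_dim/2]): the half-split convention.
    c, s = jnp.cos(ang), jnp.sin(ang)
    h = x.shape[-1] // 2
    x1, x2 = x[..., :h], x[..., h:]
    return jnp.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)

T, D = 5, 8
assert D % 2 == 0, "RoPE needs an even head_dim"
ang = angles(T, D)
x = jax.random.normal(jax.random.PRNGKey(0), (T, D))

# Same input, different conventions -> different outputs.
print(jnp.allclose(rope_adjacent(x, ang), rope_half(x, ang)))   # False

# perm places x[i] and x[i + D/2] side by side: [0, 4, 1, 5, 2, 6, 3, 7] for D = 8.
perm = jnp.stack([jnp.arange(D // 2), jnp.arange(D // 2) + D // 2], axis=-1).reshape(-1)
same = jnp.allclose(rope_adjacent(x[:, perm], ang), rope_half(x, ang)[:, perm], atol=1e-5)
print(same)                                                       # True
```

Because the same permutation applied to both q and k leaves their dot product unchanged, the attention scores match once the layouts are aligned.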
Problem
Write mha_rope(seed, x, num_heads, d_model, base):
- Define RopeMHA(nnx.Module): four nnx.Linear projections plus base as a plain attribute.
- Helper _cos_sin(T): builds (T, head_dim/2) cosine and sine tables from θ_i = base^(-2i/head_dim) and pos * θ_i.
- Helper rotate(x, cos, sin): pairs adjacent features (x[..., 0::2], x[..., 1::2]), applies the 2-D rotation, restacks and reshapes back.
- __call__: project Q/K/V, rotate Q and K (NOT V), SDPA, concat, out_proj.
- Cast num_heads and d_model to int; base stays float. Return the output flattened.
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- num_heads, d_model: ints (passed as floats). head_dim is even.
- base: float (e.g. 10000.0).
Output: 1-D flattened.