Rotary Position Embedding (RoPE)
Why this matters
RoPE (Rotary Position Embedding, Su et al. 2021) is the position scheme behind LLaMA, GPT-NeoX, PaLM, and many other modern LLMs. Sinusoidal and learned encodings add a position vector to the token embedding; RoPE instead rotates consecutive feature pairs by a position-dependent angle. Two key advantages:
- Relative position emerges naturally: for fixed q and k, the dot product RoPE(q, m) · RoPE(k, n) depends only on m - n, not on the absolute positions m and n. This is exactly what attention should care about.
- Length extrapolation: a model trained at 2k context can still be run at 4k context, although quality degrades; rescaling methods such as NTK-aware scaling of the base mitigate this.
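A quick numerical check of the relative-position property follows as a minimal sketch. `rope_rotate` and `score` are hypothetical helpers, not a library API. The sketch rotates a single vector to a given position using the interleaved pairs and base = 10000 described in the next section. Shifting both positions by the same amount should leave the score unchanged up to float32 rounding, because rotating a pair by mθ and another by nθ leaves their dot product depending only on (n - m)θ.

```python
import jax.numpy as jnp

def rope_rotate(x, pos, base=10000.0):
    """Rotate one vector x of shape (d,) to position `pos` (interleaved pairs)."""
    d = x.shape[-1]
    theta = base ** (-jnp.arange(d // 2) * 2.0 / d)   # θ_i = base^(-2i/d)
    phi = pos * theta                                  # φ_i = pos · θ_i
    a, b = x[0::2], x[1::2]                            # pairs (x[2i], x[2i+1])
    rotated = jnp.stack([a * jnp.cos(phi) - b * jnp.sin(phi),
                         a * jnp.sin(phi) + b * jnp.cos(phi)], axis=-1)
    return rotated.reshape(d)                          # interleave back to (d,)

q = jnp.array([0.3, -1.2, 0.8, 0.5, -0.7, 1.1])
k = jnp.array([1.0, 0.4, -0.6, 0.9, 0.2, -0.3])

def score(m, n):
    """Unscaled attention logit between q at position m and k at position n."""
    return jnp.dot(rope_rotate(q, m), rope_rotate(k, n))

# Same offset m - n = 3 at three different absolute positions:
print(score(5, 2), score(40, 37), score(300, 297))  # equal up to float32 rounding
```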
RoPE is applied separately to Q and K inside the attention block, before the dot product. V is left alone.
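Where RoPE sits inside attention is shown below as a minimal single-head sketch. It assumes unmasked softmax attention and the interleaved layout used on this page. `rope` and `attention_with_rope` are illustrative names, not a library API. Only q and k pass through the rotation; v enters the weighted sum unchanged.

```python
import jax
import jax.numpy as jnp

def rope(x, base=10000.0):
    """Interleaved RoPE on x of shape (T, d); row t is rotated to position t."""
    T, d = x.shape
    inv_freq = 1.0 / base ** (jnp.arange(d // 2) * 2 / d)   # θ_i
    angles = jnp.arange(T)[:, None] * inv_freq[None, :]     # (T, d/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = jnp.stack([x_even * cos - x_odd * sin,
                     x_even * sin + x_odd * cos], axis=-1)  # (T, d/2, 2)
    return out.reshape(T, d)

def attention_with_rope(q, k, v, base=10000.0):
    """Single-head softmax attention without a mask; RoPE touches only q and k."""
    q_rot, k_rot = rope(q, base), rope(k, base)
    scores = q_rot @ k_rot.T / jnp.sqrt(q.shape[-1])        # (T, T)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v                                      # v is not rotated

q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 16))
out = attention_with_rope(q, k, v)
print(out.shape)  # (8, 16)
```

Row 0 is rotated by angle 0 and so is left untouched. Each pair is rotated rigidly, so the per-row norms of q and k are preserved; only the angles between queries and keys change.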
The rotation, concretely
For a vector x of shape (d,) with even d, group consecutive pairs:
(x[0], x[1]), (x[2], x[3]), ..., (x[d-2], x[d-1]).
For pair index i ∈ [0, d/2) and position pos, define an angle:
θ_i = base^(-2i/d)
φ = pos · θ_i
Then rotate the pair (a, b) = (x[2i], x[2i+1]) by φ:
a' = a · cos(φ) - b · sin(φ)
b' = a · sin(φ) + b · cos(φ)
Stack the rotated pairs back into a vector of shape (d,). Apply this
independently per position pos.
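The definition translates directly into a slow, loop-based reference, which is handy for checking a vectorised version on small inputs. This is a sketch. `rope_loop` is a hypothetical name, and it assumes a 2-D (T, d) input whose row index is the position.

```python
import math
import jax.numpy as jnp

def rope_loop(x, base=10000.0):
    """Reference RoPE written pair by pair, exactly as in the formulas above.
    x: array of shape (T, d) with d even; row `pos` is rotated by angles pos * θ_i."""
    T, d = x.shape
    assert d % 2 == 0, "RoPE needs an even feature dimension"
    rows = []
    for pos in range(T):
        row = []
        for i in range(d // 2):
            theta_i = base ** (-2 * i / d)              # θ_i = base^(-2i/d)
            phi = pos * theta_i                         # φ = pos · θ_i
            a, b = float(x[pos, 2 * i]), float(x[pos, 2 * i + 1])
            row.append(a * math.cos(phi) - b * math.sin(phi))  # a'
            row.append(a * math.sin(phi) + b * math.cos(phi))  # b'
        rows.append(row)
    return jnp.array(rows)

x = jnp.array([[1.0, 0.0, 0.0, 1.0],
               [1.0, 0.0, 0.0, 1.0]])
out = rope_loop(x)
print(out)  # row 0 unchanged (pos = 0); row 1 ≈ [0.540, 0.841, -0.010, 1.000]
```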
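base = 10000.0 is the standard choice (used in the original paper and in LLaMA). A larger base lowers the rotation frequencies θ_i, so the slow pairs take more positions to complete a turn. This helps with longer contexts, which is why long-context models often raise it.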
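To make "lower frequencies" concrete, the sketch below prints the wavelength 2π/θ_i, the number of positions pair i needs to complete a full turn, for the fastest and slowest pairs. Three things here are assumptions for illustration: the head size d = 128, the larger base 500000.0, and the helper name `rope_wavelengths`.

```python
import jax.numpy as jnp

def rope_wavelengths(d, base):
    """Return θ_i = base^(-2i/d) and the wavelength 2π/θ_i (in positions) per pair."""
    theta = base ** (-jnp.arange(d // 2) * 2.0 / d)
    return theta, 2 * jnp.pi / theta

# d = 128 and base = 500000.0 are illustrative values, not taken from the text.
for base in (10000.0, 500000.0):
    theta, wavelength = rope_wavelengths(128, base)
    # Pair 0 always turns once every 2π ≈ 6.28 positions;
    # the slowest (last) pair's wavelength grows with base.
    print(f"base={base:>9.0f}  shortest={float(wavelength[0]):.2f}  "
          f"longest={float(wavelength[-1]):.0f}")
```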
Vectorised implementation
For x of shape (T, d):
import jax.numpy as jnp
half = d // 2 # number of feature pairs
inv_freq = 1.0 / base ** (jnp.arange(half) * 2 / d) # (d/2,)
pos = jnp.arange(T)[:, None] # (T, 1)
angles = pos * inv_freq[None, :] # (T, d/2)
cos = jnp.cos(angles)
sin = jnp.sin(angles)
x_even = x[:, 0::2] # (T, d/2)
x_odd = x[:, 1::2] # (T, d/2)
out_even = x_even * cos - x_odd * sin
out_odd = x_even * sin + x_odd * cos
Then interleave out_even and out_odd back. With JAX’s .at[...].set:
out = jnp.zeros_like(x)
out = out.at[:, 0::2].set(out_even)
out = out.at[:, 1::2].set(out_odd)
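Instead of two scatter updates, you can stack the halves along a new last axis and reshape, which gives the same interleaved layout. In this sketch the helper names are hypothetical. It checks that the two approaches agree exactly, since both only move values around.

```python
import jax
import jax.numpy as jnp

def interleave_at(out_even, out_odd):
    """Interleave with two .at[...].set scatters, as in the snippet above."""
    T, half = out_even.shape
    out = jnp.zeros((T, 2 * half), dtype=out_even.dtype)
    out = out.at[:, 0::2].set(out_even)
    return out.at[:, 1::2].set(out_odd)

def interleave_stack(out_even, out_odd):
    """Same layout without scatters: (T, d/2, 2) -> (T, d)."""
    T, half = out_even.shape
    return jnp.stack([out_even, out_odd], axis=-1).reshape(T, 2 * half)

key_e, key_o = jax.random.split(jax.random.PRNGKey(0))
out_even = jax.random.normal(key_e, (5, 4))
out_odd = jax.random.normal(key_o, (5, 4))
print(jnp.array_equal(interleave_at(out_even, out_odd),
                      interleave_stack(out_even, out_odd)))  # True
print(interleave_stack(jnp.array([[1.0, 3.0]]), jnp.array([[2.0, 4.0]])))  # [[1. 2. 3. 4.]]
```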
Worked example
x = [1, 0, 0, 1] at position 1 with d=4, base=10000:
- Pair 0: (1, 0), angle 1 · 10000^0 = 1 rad. Rotation: (cos 1, sin 1) ≈ (0.540, 0.841), so (1, 0) → (0.540, 0.841).
- Pair 1: (0, 1), angle 1 · 10000^(-2/4) = 10000^(-0.5) = 0.01 rad. Small rotation: (0, 1) → (-sin 0.01, cos 0.01) ≈ (-0.010, 1.000).
Result: [0.540, 0.841, -0.010, 1.000].
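The worked example can be reproduced in a few lines. This sketch hard-codes x, pos, d and base from the example and applies the same per-pair rotation in vectorised form.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 0.0, 0.0, 1.0])
pos, d, base = 1, 4, 10000.0

theta = base ** (-jnp.arange(d // 2) * 2.0 / d)   # θ_i = base^(-2i/d) ≈ [1.0, 0.01]
phi = pos * theta                                  # angles for position 1
a, b = x[0::2], x[1::2]                            # pairs (1, 0) and (0, 1)
out = jnp.stack([a * jnp.cos(phi) - b * jnp.sin(phi),
                 a * jnp.sin(phi) + b * jnp.cos(phi)], axis=-1).reshape(d)
print(theta)  # ≈ [1.0, 0.01]
print(out)    # ≈ [0.540, 0.841, -0.010, 1.000]
```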
Common pitfalls
- Wrong pair grouping: the original paper (and Meta's reference LLaMA code) groups (x[2i], x[2i+1]) (interleaved). GPT-NeoX and the Hugging Face LLaMA implementation group (x[i], x[i + d/2]) (split-half, often implemented as rotate_half). The two differ by a permutation of the features, so either works if applied consistently to Q and K, but pretrained weights only match the layout they were trained with. We use the interleaved form here.
- Forgetting that base is in the denominator: θ_i = base^(-2i/d), not base^(2i/d).
- Odd d: this scheme requires d % 2 == 0. For odd d the last feature has no partner, so RoPE as defined here is undefined.
- Mixing rotation matrix conventions: [[cos, -sin], [sin, cos]] rotates counter-clockwise. Flipping the signs of the sin terms gives a clockwise rotation, which still works as long as Q and K agree.
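The first pitfall is illustrated below with two hypothetical helpers, `rope_interleaved` and `rope_split`. The latter mimics the split-half layout and is not copied from any library. The sketch checks three things. Permuting the features with `perm` turns one layout into the other. Applying the same permutation to Q and K leaves the attention scores unchanged. Applying the split layout to unpermuted features gives a different rotation.

```python
import jax
import jax.numpy as jnp

def _cos_sin(T, d, base):
    inv_freq = base ** (-jnp.arange(d // 2) * 2.0 / d)   # θ_i = base^(-2i/d)
    ang = jnp.arange(T)[:, None] * inv_freq[None, :]      # (T, d/2)
    return jnp.cos(ang), jnp.sin(ang)

def rope_interleaved(x, base=10000.0):
    """Pairs (x[2i], x[2i+1]): the layout used in this problem."""
    T, d = x.shape
    cos, sin = _cos_sin(T, d, base)
    e, o = x[:, 0::2], x[:, 1::2]
    return jnp.stack([e * cos - o * sin, e * sin + o * cos], axis=-1).reshape(T, d)

def rope_split(x, base=10000.0):
    """Pairs (x[i], x[i + d/2]): the split-half ("rotate_half") layout."""
    T, d = x.shape
    cos, sin = _cos_sin(T, d, base)
    lo, hi = x[:, : d // 2], x[:, d // 2 :]
    return jnp.concatenate([lo * cos - hi * sin, lo * sin + hi * cos], axis=-1)

T, d = 6, 8
# perm sends interleaved pair (2i, 2i+1) to split positions (i, i + d/2)
perm = jnp.concatenate([jnp.arange(0, d, 2), jnp.arange(1, d, 2)])
q, k = jax.random.normal(jax.random.PRNGKey(1), (2, T, d))

# Same rotation, with features reordered by perm
same_rotation = jnp.allclose(rope_split(q[:, perm]), rope_interleaved(q)[:, perm], atol=1e-6)
# A consistent permutation of Q and K leaves attention scores unchanged
scores_inter = rope_interleaved(q) @ rope_interleaved(k).T
scores_split = rope_split(q[:, perm]) @ rope_split(k[:, perm]).T
same_scores = jnp.allclose(scores_inter, scores_split, atol=1e-4)
# Applying the split layout to unpermuted features is a different rotation
mixed = jnp.allclose(rope_split(q), rope_interleaved(q), atol=1e-3)
print(same_rotation, same_scores, mixed)  # True True False
```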
Problem
Implement rope_apply(seed, x, base):
- T, d = x.shape. seed is unused.
- Build cos/sin arrays of shape (T, d/2) using θ_i = base^(-2i/d) and φ = pos · θ_i.
- Split x into (x_even, x_odd) along the last dim using [:, 0::2] and [:, 1::2].
- Rotate: out_even = x_even·cos - x_odd·sin, out_odd = x_even·sin + x_odd·cos.
- Interleave back into a (T, d) array and return it flattened.
Inputs:
- seed: int (unused).
- x: 2-D float array of shape (T, d) with d even.
- base: float, typically 10000.0.
Output: 1-D array of length T · d.