
Rotary Position Embedding (RoPE)

Why this matters

RoPE (Rotary Position Embedding, Su et al. 2021) is the position scheme behind LLaMA, GPT-NeoX, PaLM, and many other recent LLMs. Sinusoidal and learned encodings add a position vector to the token embedding. RoPE instead rotates pairs of features by a position-dependent angle. Two key advantages:

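  1. Relative position emerges naturally: the dot product RoPE(q, m) · RoPE(k, n) depends on the positions only through the offset m - n, not on m and n individually. This is exactly what attention should care about.
  2. Length extrapolation: the rotation is defined for any position, so a model trained at 2k context can still run at 4k context. Quality degrades past the training length, though, unless the angles are rescaled (e.g. NTK-aware scaling of base, or position interpolation).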
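The relative-position property can be checked numerically. The following is a minimal sketch. It assumes the interleaved pairing described below and uses a hypothetical helper `rope_vec` that rotates a single vector. Two query/key pairs with the same offset (3) but different absolute positions should produce the same score.

```python
import jax
import jax.numpy as jnp


def rope_vec(x, pos, base=10000.0):  # hypothetical helper name
    """Rotate a single (d,) vector to position `pos` using interleaved pairs."""
    d = x.shape[-1]
    inv_freq = 1.0 / base ** (jnp.arange(d // 2) * 2 / d)  # theta_i, shape (d/2,)
    phi = pos * inv_freq                                   # angles, shape (d/2,)
    a, b = x[0::2], x[1::2]                                # pair components
    out_a = a * jnp.cos(phi) - b * jnp.sin(phi)
    out_b = a * jnp.sin(phi) + b * jnp.cos(phi)
    return jnp.stack([out_a, out_b], axis=-1).reshape(d)  # re-interleave pairs


key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (8,))
k = jax.random.normal(key_k, (8,))

# Same offset m - n = 3 at two different absolute positions.
s1 = jnp.dot(rope_vec(q, 5.0), rope_vec(k, 2.0))
s2 = jnp.dot(rope_vec(q, 103.0), rope_vec(k, 100.0))
print(s1, s2)  # equal up to float32 rounding
```

Each 2x2 rotation is orthogonal, so per pair the score equals q_pair · R((n - m)θ_i) k_pair. Only the offset survives.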

RoPE is applied separately to Q and K inside the attention block, before the dot product. V is left alone.
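Here is a sketch of where RoPE sits in a multi-head attention block. It assumes a single unbatched sequence with Q, K and V of shape (T, H, Dh), where positions are the row index. The causal mask is omitted for brevity. `rope_tokens` and `attention` are hypothetical names, not a library API.

```python
import jax
import jax.numpy as jnp


def rope_tokens(x, base=10000.0):  # hypothetical helper name
    """Interleaved RoPE over x of shape (T, H, Dh); position = index along T."""
    T, _, dh = x.shape
    inv_freq = 1.0 / base ** (jnp.arange(dh // 2) * 2 / dh)  # (Dh/2,)
    angles = jnp.arange(T)[:, None] * inv_freq[None, :]       # (T, Dh/2)
    cos = jnp.cos(angles)[:, None, :]                         # (T, 1, Dh/2): broadcast over heads
    sin = jnp.sin(angles)[:, None, :]
    a, b = x[..., 0::2], x[..., 1::2]
    out = jnp.stack([a * cos - b * sin, a * sin + b * cos], axis=-1)
    return out.reshape(x.shape)


def attention(q, k, v):  # hypothetical helper name; no causal mask
    """Multi-head attention with RoPE applied to q and k only."""
    q, k = rope_tokens(q), rope_tokens(k)                     # v is NOT rotated
    scores = jnp.einsum("thd,shd->hts", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)                 # (H, T, T)
    return jnp.einsum("hts,shd->thd", weights, v)             # (T, H, Dh)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 2, 8))
k = jax.random.normal(kk, (6, 2, 8))
v = jax.random.normal(kv, (6, 2, 8))
out = attention(q, k, v)
print(out.shape)  # (6, 2, 8)
```

The same angles are shared by every head, which is why cos and sin get a singleton head axis.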

The rotation, concretely

For a vector x of shape (d,) with even d, group consecutive pairs: (x[0], x[1]), (x[2], x[3]), ..., (x[d-2], x[d-1]). For pair index i ∈ [0, d/2) and position pos, define an angle:

θ_i  = base^(-2i/d)
φ    = pos · θ_i

Then rotate the pair (a, b) = (x[2i], x[2i+1]) by φ:

a' = a · cos(φ) - b · sin(φ)
b' = a · sin(φ) + b · cos(φ)

Stack the rotated pairs back into a vector of shape (d,). Apply this independently per position pos.
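A direct, unvectorised transcription of these formulas is useful as a reference when testing a faster version. This sketch assumes the input is a nested Python list of T rows and that row index = position. `rope_reference` is a hypothetical name.

```python
import math
import jax.numpy as jnp


def rope_reference(x, base=10000.0):  # hypothetical helper name
    """Loop-based RoPE: x is a list of T rows, each a list of d floats (d even)."""
    T, d = len(x), len(x[0])
    out = []
    for pos in range(T):
        row = []
        for i in range(d // 2):
            theta = base ** (-2 * i / d)      # θ_i = base^(-2i/d)
            phi = pos * theta                 # φ = pos · θ_i
            a, b = x[pos][2 * i], x[pos][2 * i + 1]
            row += [a * math.cos(phi) - b * math.sin(phi),   # a'
                    a * math.sin(phi) + b * math.cos(phi)]   # b'
        out.append(row)
    return jnp.array(out)


r = rope_reference([[1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
```

Row 0 comes back unchanged because every angle is 0 at position 0.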

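base = 10000.0 is the standard choice, used by the original paper and by LLaMA. Because θ_i = base^(-2i/d), a larger base lowers the frequency of every pair except i = 0. The slowest pairs then need more positions to complete a full turn. Long-context models sometimes raise base for this reason; Code Llama, for example, uses 1,000,000.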
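To see the effect of base, compute each pair's wavelength: the number of positions it takes to rotate once, 2π / θ_i. A small sketch follows. The head size d = 128 and the two base values are illustrative assumptions, and `wavelengths` is a hypothetical name.

```python
import jax.numpy as jnp


def wavelengths(d, base):  # hypothetical helper name
    """Positions needed for pair i to complete one full turn: 2π / θ_i."""
    theta = base ** (-2 * jnp.arange(d // 2) / d)
    return 2 * jnp.pi / theta


for base in (10000.0, 1_000_000.0):
    w = wavelengths(128, base)
    print(f"base={base:>9.0f}  fastest={float(w[0]):.2f}  slowest={float(w[-1]):.0f}")
```

The fastest pair always has wavelength 2π, whatever the base. Only the slower pairs stretch out as base grows.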

Vectorised implementation

For shape (T, d):

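import jax.numpy as jnp

T, d = x.shape                                         # x: (T, d) with d even; base: float
half = d // 2
inv_freq = 1.0 / base ** (jnp.arange(half) * 2 / d)   # θ_i, shape (d/2,)
pos = jnp.arange(T)[:, None]                           # positions 0..T-1, shape (T, 1)
angles = pos * inv_freq[None, :]                       # φ = pos · θ_i, shape (T, d/2)
cos = jnp.cos(angles)
sin = jnp.sin(angles)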

x_even = x[:, 0::2]                                    # (T, d/2)
x_odd  = x[:, 1::2]                                    # (T, d/2)
out_even = x_even * cos - x_odd * sin
out_odd  = x_even * sin + x_odd * cos

Then interleave out_even and out_odd back. With JAX’s .at[...].set:

out = jnp.zeros_like(x)
out = out.at[:, 0::2].set(out_even)
out = out.at[:, 1::2].set(out_odd)
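The fragments above can be wrapped into one function. This sketch swaps the two .at[...].set calls for an equivalent stack-and-reshape interleave, then batches the function with jax.vmap. It assumes every sequence in the batch starts at position 0. `rope_interleaved` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rope_interleaved(x, base=10000.0):  # hypothetical helper name
    """Vectorised interleaved RoPE for x of shape (T, d), d even."""
    T, d = x.shape
    half = d // 2
    inv_freq = 1.0 / base ** (jnp.arange(half) * 2 / d)   # (d/2,)
    angles = jnp.arange(T)[:, None] * inv_freq[None, :]   # (T, d/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos
    # Stack each rotated pair on a new trailing axis, (T, d/2, 2), then merge
    # the last two axes: this yields [e0, o0, e1, o1, ...], the interleaved order.
    return jnp.stack([out_even, out_odd], axis=-1).reshape(T, d)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 5, 8))  # (B, T, d)
y = jax.jit(jax.vmap(rope_interleaved))(x)              # positions 0..T-1 per sequence
```

The stack/reshape form avoids materialising a zeros array. Both forms produce the same layout.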

Worked example

x = [1, 0, 0, 1] at position 1 with d=4, base=10000:

  • Pair 0: (1, 0), angle 1 · 10000^0 = 1 rad. Rotation: (cos 1, sin 1) ≈ (0.540, 0.841). So (1, 0) → (0.540, 0.841).
  • Pair 1: (0, 1), angle 1 · 10000^(-2/4) = 10000^(-0.5) = 0.01 rad. Small rotation: (0, 1) → (-sin 0.01, cos 0.01) ≈ (-0.010, 1.000).

Result: [0.540, 0.841, -0.010, 1.000].
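The arithmetic can be checked in a few lines. Note that θ_1 = 10000^(-2·1/4) = 10000^(-0.5) = 0.01, so pair 1 rotates by 0.01 rad. This is a minimal sketch for the single vector and position used above.

```python
import jax.numpy as jnp

# d = 4, base = 10000, x = [1, 0, 0, 1] at position pos = 1.
x = jnp.array([1.0, 0.0, 0.0, 1.0])
pos, d, base = 1, 4, 10000.0

theta = base ** (-2 * jnp.arange(d // 2) / d)   # ≈ [1.0, 0.01]
phi = pos * theta
a, b = x[0::2], x[1::2]
out = jnp.stack([a * jnp.cos(phi) - b * jnp.sin(phi),
                 a * jnp.sin(phi) + b * jnp.cos(phi)], axis=-1).reshape(d)
print(jnp.round(out, 3))
```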

Common pitfalls

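  • Wrong pair grouping: the RoFormer paper, and Meta's original LLaMA code (which multiplies complex numbers built from consecutive features), pair (x[2i], x[2i+1]) (interleaved). GPT-NeoX and the Hugging Face rotate_half implementation pair (x[i], x[i + d/2]) (split-style). The two differ only by a permutation of the features, so either works if applied consistently to Q and K, but pretrained weights only match the convention they were trained with. We use the interleaved form here.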
  • Forgetting base is in the denominator: θ_i = base^(-2i/d), not base^(2i/d).
  • d odd: this scheme requires d % 2 == 0. For odd d, RoPE is undefined.
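  • Mixing rotation matrix conventions: [[cos, -sin], [sin, cos]] rotates counter-clockwise. Swapping the signs of the sin terms gives a clockwise rotation. That still yields relative-position scores as long as Q and K use the same convention and, for pretrained weights, the same convention as training.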
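The permutation claim in the first pitfall can be checked directly. Take the interleaved features and reorder them so that pair i sits at (i, i + d/2). Split-style RoPE on the reordered input then equals the same reordering of interleaved RoPE's output. Dot products are invariant under a shared permutation, so attention scores agree as well. This is a sketch, and `rope_interleaved`, `rope_split` and `to_split` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def _angles(T, d, base):
    inv_freq = 1.0 / base ** (jnp.arange(d // 2) * 2 / d)
    ang = jnp.arange(T)[:, None] * inv_freq[None, :]    # (T, d/2)
    return jnp.cos(ang), jnp.sin(ang)


def rope_interleaved(x, base=10000.0):  # hypothetical: pairs (x[2i], x[2i+1])
    cos, sin = _angles(*x.shape, base)
    a, b = x[:, 0::2], x[:, 1::2]
    return jnp.stack([a * cos - b * sin, a * sin + b * cos], axis=-1).reshape(x.shape)


def rope_split(x, base=10000.0):  # hypothetical: pairs (x[i], x[i + d/2])
    cos, sin = _angles(*x.shape, base)
    half = x.shape[1] // 2
    a, b = x[:, :half], x[:, half:]
    return jnp.concatenate([a * cos - b * sin, a * sin + b * cos], axis=-1)


def to_split(x):  # hypothetical: move interleaved pair i to positions (i, i + d/2)
    return jnp.concatenate([x[:, 0::2], x[:, 1::2]], axis=-1)


x = jax.random.normal(jax.random.PRNGKey(0), (6, 8))
lhs = rope_split(to_split(x))
rhs = to_split(rope_interleaved(x))
print(jnp.max(jnp.abs(lhs - rhs)))  # same rotation, different feature order
```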

Problem

Implement rope_apply(seed, x, base):

  1. T, d = x.shape. seed is unused.
  2. Build cos/sin arrays of shape (T, d/2) using θ_i = base^(-2i/d) and φ = pos · θ_i, where pos is the row index 0 to T-1.
  3. Split x into (x_even, x_odd) along the last dim using [:, 0::2] and [:, 1::2].
  4. Rotate: out_even = x_even·cos - x_odd·sin, out_odd = x_even·sin + x_odd·cos.
  5. Interleave back into a (T, d) array and return flattened.

Inputs:

  • seed: int (unused).
  • x: 2-D float (T, d) with d even.
  • base: float — typically 10000.0.

Output: 1-D array of length T · d.
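One possible sketch that follows the five steps above. It assumes x may arrive as a nested list and casts it to float32. The explicit odd-d check is an addition, not part of the spec.

```python
import jax.numpy as jnp


def rope_apply(seed, x, base):
    """Interleaved RoPE on x of shape (T, d); returns a flat array of length T * d."""
    del seed                                               # unused, per the spec
    x = jnp.asarray(x, dtype=jnp.float32)
    T, d = x.shape
    if d % 2:                                              # added check (assumption)
        raise ValueError("RoPE needs an even feature dimension d")
    inv_freq = 1.0 / base ** (jnp.arange(d // 2) * 2 / d)  # θ_i, (d/2,)
    angles = jnp.arange(T)[:, None] * inv_freq[None, :]    # φ = pos · θ_i, (T, d/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = jnp.zeros_like(x)
    out = out.at[:, 0::2].set(x_even * cos - x_odd * sin)  # rotated first components
    out = out.at[:, 1::2].set(x_even * sin + x_odd * cos)  # rotated second components
    return out.reshape(-1)                                 # flatten to (T * d,)


r = rope_apply(0, [[1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]], 10000.0)
```

The first row (position 0) is returned unchanged. The second row is the d = 4 vector [1, 0, 0, 1] rotated to position 1.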

