NNX ALiBi Attention
Why this matters
Standard Transformers learn position embeddings — fixed-vocabulary lookups added to the input. They work, but they generalize badly: a model trained at length 2k usually breaks at length 8k because those longer position embeddings were never seen during training.
ALiBi (Press, Smith, & Lewis, 2022 — “Train Short, Test Long”) eliminates position embeddings entirely. Instead, attention scores get a per-head linear bias proportional to the query-key distance:
bias[h, i, j] = -m_h * |i - j|
Closer keys get a smaller penalty; far ones a larger one. Because
the bias is added to scores BEFORE softmax, the model gets a built-in
“recency” prior — head h softens or sharpens it via the slope
m_h. ALiBi-trained models extrapolate to context lengths much
longer than those seen in training, without retraining or fine-tuning.
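Written out for a single head, the biased attention weight looks as follows. The symbols q, k, and d_h (query, key, and per-head dimension) are notation assumed here for illustration; only m_h and |i - j| come from the definition above.

```latex
a^{(h)}_{ij}
  = \frac{\exp\!\left( q^{(h)}_i \cdot k^{(h)}_j / \sqrt{d_h} \;-\; m_h\,|i - j| \right)}
         {\sum_{j'} \exp\!\left( q^{(h)}_i \cdot k^{(h)}_{j'} / \sqrt{d_h} \;-\; m_h\,|i - j'| \right)}

% For two keys j and j' with equal content scores, the ratio depends only on distance:
\frac{a^{(h)}_{ij}}{a^{(h)}_{ij'}} = \exp\!\left( -m_h \left( |i - j| - |i - j'| \right) \right)
```

In other words, each extra step of distance multiplies a key's weight by exp(-m_h), which is why a larger slope sharpens the prior and a smaller slope softens it.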
The slopes
Each head gets a different slope m_h. Press et al. proposed a
fixed geometric sequence:
m_h = 2^(-8 * (h + 1) / H)
For H = 8: slopes are 2^(-1), 2^(-2), ..., 2^(-8) = [0.5, 0.25, ..., 1/256].
For H = 16: slopes interpolate to 2^(-0.5), 2^(-1.0), ..., 2^(-8).
Fixed powers of 2 (with fractional exponents when H > 8): no learning, no parameters.
Heads with large slopes (m_h ≈ 0.5) produce strong locality bias —
they specialize in nearby tokens. Heads with small slopes
(m_h ≈ 1/256) are nearly position-free — they look globally.
The set of heads spans a spectrum of attention “ranges.”
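The spectrum of ranges can be checked numerically. The sketch below computes the slopes for H = 8 and H = 16, then probes each head with all content scores set to zero so that only the bias shapes the weights. The helper name alibi_slopes, the probe length T = 65, and the "within 2 tokens" window are illustrative choices, not part of the original.

```python
import jax
import jax.numpy as jnp


def alibi_slopes(num_heads: int) -> jnp.ndarray:
    """Geometric slopes m_h = 2^(-8 * (h + 1) / H) for h = 0..H-1."""
    h_idx = jnp.arange(num_heads)
    return jnp.power(2.0, -8.0 * (h_idx + 1) / num_heads)


slopes8 = alibi_slopes(8)    # 2^-1, 2^-2, ..., 2^-8
slopes16 = alibi_slopes(16)  # 2^-0.5, 2^-1, ..., 2^-8

# Hypothetical probe: with equal (zero) content scores, the weights of the
# middle query depend only on the bias, exposing each head's "range".
T = 65
dist = jnp.abs(jnp.arange(T) - T // 2).astype(jnp.float32)            # distance to the middle query
weights = jax.nn.softmax(-slopes8[:, None] * dist[None, :], axis=-1)  # (H, T)

# Fraction of each head's attention that lands within 2 tokens of the query.
local_mass = jnp.where(dist <= 2, weights, 0.0).sum(axis=-1)

print(slopes8)
print(local_mass)
```

Under these assumptions the first head (m_h = 0.5) keeps most of its mass on the immediate neighborhood, while the last head (m_h = 1/256) spreads it almost uniformly.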
The distance matrix
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
dist = jnp.abs(i - j).astype(jnp.float32) # (T, T)
Per-head bias: shape (H, T, T):
bias = -slopes[:, None, None] * dist[None, :, :]
Each head’s slope multiplies the distance, with a leading minus. Then add to scores BEFORE softmax:
scores = scores + bias
weights = jax.nn.softmax(scores, axis=-1)
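Putting the three snippets together gives a small end-to-end check of the shapes. This is a minimal sketch: the sizes H = 4 and T = 6 and the random stand-in scores are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

H, T = 4, 6  # small, hypothetical sizes chosen for illustration

# Stand-in content scores; in a real layer these come from q @ k^T / sqrt(Dh).
scores = jax.random.normal(jax.random.PRNGKey(0), (H, T, T))

slopes = jnp.power(2.0, -8.0 * (jnp.arange(H) + 1) / H)  # (H,)
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
dist = jnp.abs(i - j).astype(jnp.float32)                # (T, T)
bias = -slopes[:, None, None] * dist[None, :, :]         # (H, T, T)

# Bias goes onto the scores before softmax, so every row still normalizes.
weights = jax.nn.softmax(scores + bias, axis=-1)         # (H, T, T)

print(bias.shape)
print(bias[0])  # head 0 (slope 0.25): zero diagonal, -0.25 per step away from it
```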
Note: ALiBi as described here uses the symmetric |i - j|. For
causal LMs the original paper actually masks j > i first, so
only past positions get the linear penalty. We’re keeping it
symmetric here (no causal mask) for clarity — the slope formula
and broadcasting choreography are the educational point.
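For reference, one common way to combine the bias with a causal mask is sketched below. It is an illustration under stated assumptions (random stand-in scores, hypothetical sizes, -inf masking), not a reproduction of the paper's code.

```python
import jax
import jax.numpy as jnp

H, T = 4, 5  # hypothetical sizes
scores = jax.random.normal(jax.random.PRNGKey(0), (H, T, T))  # stand-in content scores

slopes = jnp.power(2.0, -8.0 * (jnp.arange(H) + 1) / H)
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
dist = jnp.abs(i - j).astype(jnp.float32)
bias = -slopes[:, None, None] * dist[None, :, :]               # (H, T, T)

# Causal mask: query i may only see keys j <= i. Masked logits are set to
# -inf so they receive exactly zero weight after softmax. Every row keeps
# at least its diagonal entry, so no row is fully masked.
causal = (j <= i)[None, :, :]                                  # (1, T, T), broadcasts over heads
masked = jnp.where(causal, scores + bias, -jnp.inf)
weights = jax.nn.softmax(masked, axis=-1)

print(weights[0])  # lower-triangular: zeros above the diagonal
```

With the mask in place, only the j <= i half of the distance matrix ever contributes, so the penalty applies to past positions alone.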
Worked sketch
class AlibiMHA(nnx.Module):
    # __init__: standard four nnx.Linear projections.
    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
        h_idx = jnp.arange(H)
        slopes = jnp.power(2.0, -8.0 * (h_idx + 1) / H)
        i = jnp.arange(T)[:, None]
        j = jnp.arange(T)[None, :]
        dist = jnp.abs(i - j).astype(jnp.float32)
        bias = -slopes[:, None, None] * dist[None, :, :]
        scores = scores + bias
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)
Why this extrapolates
Standard absolute position embeddings give the model only a finite
“alphabet” of positions; novel positions don’t have embeddings.
ALiBi’s bias is purely a function of |i - j|, defined for any
distance. A model trained on T = 1024 can attend over T = 16384
at inference and the bias structure looks the same. No new
parameters, no out-of-distribution embeddings.
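The claim that the bias structure looks the same at any length can be checked directly. In this sketch the helper alibi_bias and the lengths 16 and 64 are illustrative stand-ins for a training and an inference length.

```python
import jax.numpy as jnp


def alibi_bias(num_heads: int, seq_len: int) -> jnp.ndarray:
    """Symmetric ALiBi bias of shape (H, T, T); no learned parameters involved."""
    slopes = jnp.power(2.0, -8.0 * (jnp.arange(num_heads) + 1) / num_heads)
    pos = jnp.arange(seq_len)
    dist = jnp.abs(pos[:, None] - pos[None, :]).astype(jnp.float32)
    return -slopes[:, None, None] * dist[None, :, :]


short = alibi_bias(8, 16)  # stands in for the training length
long = alibi_bias(8, 64)   # stands in for a longer inference length

# The long-context bias contains the short-context bias unchanged ...
same_prefix = bool(jnp.array_equal(long[:, :16, :16], short))
# ... and unseen distances simply continue the same straight line.
farthest = long[:, 0, -1]  # equals -slopes * 63

print(same_prefix)
print(farthest)
```

This only shows that the bias itself is well defined at longer lengths; how well a trained model's quality holds up there is an empirical question that the sketch does not address.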
Common pitfalls
- Forgetting the minus sign in front of slopes. ALiBi penalizes far tokens; the bias should be negative. bias = slopes * dist (no minus) inverts the prior: far tokens become PREFERRED.
- Slopes per-head wrong. The +1 in (h + 1) matters. With h_idx = 0..H-1, the formula 2^(-8 * (h+1)/H) gives slopes from 2^(-8/H) down to 2^(-8). Off-by-one shifts the whole spectrum.
- Adding bias after softmax. Bias must be added to the scores (the pre-softmax logits). Adding it to the weights post-softmax breaks normalization.
- Broadcasting the wrong way. slopes[:, None, None] has shape (H, 1, 1). dist[None, :, :] has shape (1, T, T). The product is (H, T, T), which matches scores.
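Three of these pitfalls are easy to reproduce on a toy input. The sketch below uses zero content scores and hypothetical sizes H = 2, T = 4 so that the bias alone determines the outcome.

```python
import jax
import jax.numpy as jnp

H, T = 2, 4  # hypothetical sizes
scores = jnp.zeros((H, T, T))  # equal content scores isolate the effect of the bias
slopes = jnp.power(2.0, -8.0 * (jnp.arange(H) + 1) / H)  # (H,)
pos = jnp.arange(T)
dist = jnp.abs(pos[:, None] - pos[None, :]).astype(jnp.float32)

# Correct broadcasting: (H, 1, 1) * (1, T, T) -> (H, T, T), same as scores.
good_bias = -slopes[:, None, None] * dist[None, :, :]
assert good_bias.shape == scores.shape

good = jax.nn.softmax(scores + good_bias, axis=-1)
flipped = jax.nn.softmax(scores - good_bias, axis=-1)  # pitfall: missing minus sign
late = jax.nn.softmax(scores, axis=-1) + good_bias     # pitfall: bias added after softmax

print(good[0, 0])     # largest weight on the nearest key (j = 0)
print(flipped[0, 0])  # largest weight on the farthest key (j = T - 1)
print(late.sum(-1))   # rows no longer sum to 1
```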
Problem
Write mha_alibi(seed, x, num_heads, d_model):
- Define AlibiMHA(nnx.Module): standard four nnx.Linear projections.
- __call__(x): scaled scores, then build per-head slopes m_h = 2^(-8 * (h + 1) / H), distance |i - j|, bias -slopes * dist broadcast to (H, T, T). Add to scores BEFORE softmax. Continue with SDPA.
- Cast num_heads, d_model to int. Return flattened.
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- num_heads, d_model: ints (passed as floats).
Output: 1-D flattened.