
NNX ALiBi Attention

Why this matters

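Many Transformers use absolute position embeddings: one vector per position index, added to the token embeddings (learned in GPT-style models, fixed sinusoids in the original Transformer). They work, but they generalize badly to longer inputs. A model trained at length 2k usually degrades sharply at length 8k, because positions past 2k either have no embedding at all (learned tables) or have embeddings the model never saw during training (sinusoids).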

ALiBi (Press, Smith, & Lewis, 2022 — “Train Short, Test Long”) eliminates position embeddings entirely. Instead, attention scores get a per-head linear bias proportional to the query-key distance:

bias[h, i, j] = -m_h * |i - j|

Nearby keys get a small penalty and distant keys a large one. Because the bias is added to the scores BEFORE softmax, the model starts with a built-in “recency” prior, and each head's fixed slope m_h sets how sharply that prior falls off with distance. Press et al. report that ALiBi-trained models extrapolate to inputs longer than those seen in training, without retraining or fine-tuning.
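To see what the prior does on its own, the sketch below sets every content score to zero so that only the ALiBi bias shapes the softmax. It assumes a single head with slope m = 0.5 and a query at position i = 3 in a length-8 sequence; these values and the helper name `alibi_row` are illustrative choices, not part of the original text.

```python
import jax
import jax.numpy as jnp

def alibi_row(m, i, T):
    """Hypothetical helper: ALiBi bias -m * |i - j| for one query i over keys j = 0..T-1."""
    j = jnp.arange(T)
    return -m * jnp.abs(i - j).astype(jnp.float32)

m, i, T = 0.5, 3, 8
bias = alibi_row(m, i, T)                 # values: -1.5, -1.0, -0.5, 0, -0.5, -1.0, -1.5, -2.0
scores = jnp.zeros(T)                     # pretend every key matches the query equally well
weights = jax.nn.softmax(scores + bias)   # bias added BEFORE softmax

# Moving one token closer to the query multiplies the weight by exp(m).
ratio = weights[i + 1] / weights[i + 2]
print(ratio, jnp.exp(m))
```

With no content signal, attention decays geometrically with distance, by a factor of exp(-m_h) per token; the learned q·k scores have to overcome this discount for a head to attend far away.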

The slopes

Each head gets a different slope m_h. Press et al. proposed a fixed geometric sequence:

m_h = 2^(-8 * (h + 1) / H),   for heads h = 0, 1, ..., H - 1

For H = 8 the slopes are 2^(-1), 2^(-2), ..., 2^(-8) = [0.5, 0.25, ..., 1/256]. For H = 16 they are 2^(-0.5), 2^(-1), ..., 2^(-8), which fills in the geometric midpoints between the H = 8 slopes. All of them are fixed powers of 2: nothing is learned and no parameters are added.
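The slope formula is easy to check numerically. This is a minimal sketch; the function name `alibi_slopes` is a hypothetical choice for illustration.

```python
import jax.numpy as jnp

def alibi_slopes(num_heads):
    """Geometric ALiBi slopes m_h = 2^(-8 (h + 1) / H) for h = 0..H-1."""
    h = jnp.arange(num_heads)
    return jnp.power(2.0, -8.0 * (h + 1) / num_heads)

s8 = alibi_slopes(8)     # 1/2, 1/4, ..., 1/256
s16 = alibi_slopes(16)   # 2^-0.5, 2^-1, ..., 2^-8

# Every other H = 16 slope reproduces the H = 8 set; the slopes in between
# are the geometric means of their neighbours.
print(s8)
print(s16[1::2])
```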

Heads with large slopes (m_h ≈ 0.5) impose a strong locality bias and specialize in nearby tokens. Heads with small slopes (m_h ≈ 1/256) are only weakly position-dependent over short ranges, so they can look much farther. Together the heads span a spectrum of attention “ranges.”
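One way to make the “range” of each head concrete: with all content scores set to zero, measure how much of a head's attention mass lands on the few nearest keys. The sketch assumes H = 8, a query at position 0, T = 256 keys, and a window of 8 keys; all of these are illustrative choices.

```python
import jax
import jax.numpy as jnp

H, T, window = 8, 256, 8
slopes = jnp.power(2.0, -8.0 * (jnp.arange(H) + 1) / H)

# Distances from a query at position 0 to every key.
d = jnp.arange(T, dtype=jnp.float32)
bias = -slopes[:, None] * d[None, :]          # (H, T)
weights = jax.nn.softmax(bias, axis=-1)       # zero content scores: prior only

# Fraction of each head's attention mass on the `window` nearest keys.
near_mass = weights[:, :window].sum(axis=-1)  # (H,)
print(near_mass)
```

Under these assumptions the steepest head (m = 0.5) puts about 98% of its mass on the 8 nearest keys, while the flattest (m = 1/256) puts only about 5% there, with the heads in between falling off monotonically.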

The distance matrix

i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
dist = jnp.abs(i - j).astype(jnp.float32)   # (T, T)

Per-head bias: shape (H, T, T):

bias = -slopes[:, None, None] * dist[None, :, :]

Each head’s slope multiplies the distance, with a leading minus. Then add to scores BEFORE softmax:

scores = scores + bias
weights = jax.nn.softmax(scores, axis=-1)
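Putting the distance matrix, the per-head bias, and the pre-softmax addition together for a small case also shows what “add before softmax” means: it is the same as scaling the unbiased attention weights by exp(-m_h * |i - j|) and renormalizing each row. The sketch assumes H = 4, T = 6, and random logits standing in for the real q @ k^T / sqrt(Dh) scores.

```python
import jax
import jax.numpy as jnp

H, T = 4, 6
slopes = jnp.power(2.0, -8.0 * (jnp.arange(H) + 1) / H)          # (H,)
pos = jnp.arange(T)
dist = jnp.abs(pos[:, None] - pos[None, :]).astype(jnp.float32)  # (T, T)
bias = -slopes[:, None, None] * dist[None, :, :]                 # (H, T, T)

# Random logits stand in for the scaled q @ k^T scores of a real layer.
scores = jax.random.normal(jax.random.PRNGKey(0), (H, T, T))

# ALiBi: add the bias to the logits, then normalize over keys.
weights = jax.nn.softmax(scores + bias, axis=-1)

# Equivalent view: scale the unbiased weights by exp(-m_h * |i - j|),
# then renormalize each row.
reweighted = jax.nn.softmax(scores, axis=-1) * jnp.exp(bias)
reweighted = reweighted / reweighted.sum(axis=-1, keepdims=True)
```

Content scores decide which keys are relevant; the bias then discounts each of them multiplicatively according to its distance from the query.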

Note: this section uses the symmetric |i - j|, so keys on both sides of the query are penalized. The original paper targets causal language models: future positions (j > i) are masked out before softmax, and the linear penalty applies only to past keys, where |i - j| = i - j. We keep the symmetric form here, with no causal mask, for clarity. The slope formula and the broadcasting choreography are the educational point.
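For comparison, a minimal sketch of the causal variant: the bias is -m_h * (i - j) for keys at or before the query, and future keys are masked with -inf before softmax. The function name `causal_alibi_bias` is hypothetical, and masking with -inf is one common choice (a large negative constant also works).

```python
import jax
import jax.numpy as jnp

def causal_alibi_bias(slopes, T):
    """(H, T, T) bias: -m_h * (i - j) for keys j <= i, -inf for future keys j > i."""
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    dist = (i - j).astype(jnp.float32)                # >= 0 on and below the diagonal
    bias = -slopes[:, None, None] * dist[None, :, :]
    causal = (j <= i)[None, :, :]                     # (1, T, T) allowed positions
    return jnp.where(causal, bias, -jnp.inf)

H, T = 4, 5
slopes = jnp.power(2.0, -8.0 * (jnp.arange(H) + 1) / H)
scores = jnp.zeros((H, T, T))                         # prior only, no content signal
weights = jax.nn.softmax(scores + causal_alibi_bias(slopes, T), axis=-1)
```

Every row keeps at least its diagonal entry finite, so the softmax never sees an all -inf row and produces no NaNs; the first query can only attend to itself.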

Worked sketch

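import jax
import jax.numpy as jnp
from flax import nnx

class AlibiMHA(nnx.Module):
    # Constructor signature is an assumption; the text only specifies
    # "standard four nnx.Linear projections". Assumes d_model % num_heads == 0.
    def __init__(self, d_model: int, num_heads: int, *, rngs: nnx.Rngs):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x):                       # x: (T, d_model), unbatched
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # Project and split heads: (T, d_model) -> (H, T, Dh)
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        # Scaled dot-product scores: (H, T, T)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
        # Per-head slopes m_h = 2^(-8 (h + 1) / H): (H,)
        h_idx = jnp.arange(H)
        slopes = jnp.power(2.0, -8.0 * (h_idx + 1) / H)
        # Symmetric distance |i - j|: (T, T)
        i = jnp.arange(T)[:, None]
        j = jnp.arange(T)[None, :]
        dist = jnp.abs(i - j).astype(jnp.float32)
        # (H, 1, 1) * (1, T, T) -> (H, T, T), added BEFORE softmax
        bias = -slopes[:, None, None] * dist[None, :, :]
        scores = scores + bias
        weights = jax.nn.softmax(scores, axis=-1)
        # Weighted values, then merge heads: (H, T, Dh) -> (T, H * Dh)
        per_head = jnp.matmul(weights, v)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)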

Why this extrapolates

Learned absolute position embeddings give the model only a finite “alphabet” of positions: positions beyond the training length have no embedding at all. ALiBi's bias depends only on |i - j| and is defined for any distance. A model trained at T = 1024 can be run at T = 16384 with the bias computed by exactly the same formula: no new parameters and no out-of-distribution embeddings. Keys farther away than anything seen in training simply receive larger penalties.
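A quick way to see that nothing “runs out” at longer lengths: the bias built for the training length is exactly a sub-block of the bias built for a longer test length, and so is every equally sized window along the diagonal. The sketch uses small stand-in lengths (16 and 64 instead of 1024 and 16384) and a hypothetical helper `alibi_bias`.

```python
import jax.numpy as jnp

def alibi_bias(num_heads, T):
    """Hypothetical helper: symmetric ALiBi bias of shape (H, T, T) for any length T."""
    slopes = jnp.power(2.0, -8.0 * (jnp.arange(num_heads) + 1) / num_heads)
    pos = jnp.arange(T)
    dist = jnp.abs(pos[:, None] - pos[None, :]).astype(jnp.float32)
    return -slopes[:, None, None] * dist[None, :, :]

H, T_train, T_test = 8, 16, 64   # small stand-ins for 1024 and 16384
b_train = alibi_bias(H, T_train)
b_test = alibi_bias(H, T_test)

# The training-length bias is the top-left block of the longer one ...
same_block = jnp.allclose(b_test[:, :T_train, :T_train], b_train)
# ... and any T_train-sized window along the diagonal matches it as well.
shift = 40
same_window = jnp.allclose(
    b_test[:, shift:shift + T_train, shift:shift + T_train], b_train)
```

A learned position table, by contrast, would simply have no rows for indices at or beyond T_train.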

Common pitfalls

  • Forgetting the minus sign in front of slopes. ALiBi penalizes far tokens; the bias should be negative. bias = slopes * dist (no minus) inverts the prior — far tokens become PREFERRED.
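  • Wrong per-head slopes. The +1 in (h + 1) matters. With h_idx = 0..H-1, the formula 2^(-8 * (h + 1) / H) gives slopes from 2^(-8/H) down to 2^(-8). Dropping the +1 gives head 0 a slope of 2^0 = 1 and stops the smallest slope at 2^(-8(H-1)/H) instead of 2^(-8), shifting the whole spectrum by one step.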
  • Adding bias after softmax. The bias must be added to the scores (the logits) before softmax. Adding it to the post-softmax weights breaks normalization: rows no longer sum to 1, and weights can even become negative.
  • Broadcasting the wrong way. slopes[:, None, None] shape (H, 1, 1). dist[None, :, :] shape (1, T, T). Product is (H, T, T) — matches scores.
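The sign pitfall and the post-softmax pitfall are easy to observe side by side. The sketch assumes one head with slope m = 0.5, T = 8, and zero content scores, so only the bias decides where attention goes.

```python
import jax
import jax.numpy as jnp

T, m = 8, 0.5
dist = jnp.abs(jnp.arange(T)[:, None] - jnp.arange(T)[None, :]).astype(jnp.float32)
scores = jnp.zeros((T, T))                            # no content signal

good = jax.nn.softmax(scores - m * dist, axis=-1)     # correct: penalty before softmax
flipped = jax.nn.softmax(scores + m * dist, axis=-1)  # missing minus sign
post = jax.nn.softmax(scores, axis=-1) - m * dist     # bias applied after softmax

# For query 0 the correct version favours key 0; the flipped one favours the farthest key.
print(int(jnp.argmax(good[0])), int(jnp.argmax(flipped[0])))
# Rows of `post` no longer sum to 1 and contain negative "probabilities".
print(post.sum(axis=-1)[0], float(post.min()))
```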

Problem

Write mha_alibi(seed, x, num_heads, d_model):

  1. Define AlibiMHA(nnx.Module): standard four nnx.Linear projections.
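  2. __call__(x): compute the scaled scores, then build the per-head slopes m_h = 2^(-8 * (h + 1) / H), the distance matrix |i - j|, and the bias -slopes * dist broadcast to (H, T, T). Add the bias to the scores BEFORE softmax, then finish standard scaled dot-product attention (softmax, weighted sum of values, merge heads, output projection).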
  3. Cast seed, num_heads, and d_model to int (all three arrive as floats). Return the output flattened to 1-D.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model: ints (passed as floats).

Output: the (T, d_model) attention output, flattened to 1-D.

