
ALiBi: Attention with Linear Biases

Why this matters

Most position encodings (sinusoidal, learned, RoPE) inject position information INTO the embeddings. ALiBi (Press et al., 2022) takes a different angle: don’t change the embeddings; just bias the attention scores so that distant tokens get penalized.

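Specifically, the score for query i and key j gets a bias -m_h * |i - j|, where m_h is a per-head slope. The diagonal (i = j) is unpenalized and the penalty grows linearly with distance, so far tokens get progressively more negative scores and therefore less attention after the softmax. (The paper uses a causal LM, where only j ≤ i is visible and |i - j| = i - j; the symmetric |i - j| form used here also covers bidirectional attention.)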
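To see the effect on attention weights, here is a small NumPy sketch that adds the ALiBi penalty to one row of scores before the softmax. The helper name alibi_softmax_row is illustrative, not from any library. With all raw scores equal, a plain softmax would spread attention uniformly; the bias makes the weights fall off with distance from the query.

```python
import numpy as np


def alibi_softmax_row(raw_scores, i, slope):
    """Attention weights for query position i after adding -slope * |i - j|."""
    j = np.arange(raw_scores.shape[0])
    biased = raw_scores - slope * np.abs(i - j)   # ALiBi penalty, zero at j == i
    w = np.exp(biased - biased.max())             # numerically stable softmax
    return w / w.sum()


# Equal raw scores: without the bias every key would get 1/5 of the attention.
weights = alibi_softmax_row(np.zeros(5), i=0, slope=0.5)
print(np.round(weights, 3))
```

With slope 0.5 and query 0, the weights come out to roughly 0.43, 0.26, 0.16, 0.10, 0.06: the same raw scores, but a clear preference for nearby keys.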

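Key practical advantage: ALiBi extrapolates. The penalty depends only on the distance |i - j|, so a longer sequence just extends the same linear ramp, and a model trained on a short context can run on a longer one at inference and still attend reasonably. The paper reports that a 1.3B-parameter model trained on 1,024-token inputs matches, at 2,048 tokens, the perplexity of a sinusoidal model trained on 2,048-token inputs. Learned absolute embeddings have no entries past the training length, and unmodified RoPE tends to degrade beyond its training length.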
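The length independence is easy to check: the bias for a short context is literally the top-left block of the bias for a longer one, so nothing new has to be learned at the longer length. A NumPy sketch, with alibi_bias as an illustrative helper and arbitrary slopes:

```python
import numpy as np


def alibi_bias(slopes, T):
    """Symmetric ALiBi bias of shape (H, T, T): bias[h, i, j] = -slopes[h] * |i - j|."""
    pos = np.arange(T)
    dist = np.abs(pos[:, None] - pos[None, :])    # (T, T) distance matrix
    return -slopes[:, None, None] * dist[None]     # broadcast to (H, T, T)


slopes = np.array([0.5, 0.25])
short_bias = alibi_bias(slopes, 4)   # "training" length
long_bias = alibi_bias(slopes, 8)    # longer "inference" length

# The short-context bias is exactly the top-left block of the long one.
print(np.array_equal(long_bias[:, :4, :4], short_bias))   # True
```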

BLOOM (176B) used ALiBi. So did MPT-7B/30B.

The slopes

Per-head slope: m_h = 2^(-8 * (h + 1) / H) for h in [0, H).

For H=8: slopes ≈ [0.5, 0.25, 0.125, 0.0625, 0.0313, 0.0156, 0.0078, 0.0039]. This is a geometric sequence with ratio 2^(-8/H) (here 1/2): early heads decay attention rapidly with distance, while later heads barely decay, so they are nearly position-agnostic. The mix gives the model a built-in inductive bias toward locality while keeping some heads “global”.
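A direct NumPy transcription of the slope formula reproduces the H=8 list above; alibi_slopes is an illustrative name, and jax.numpy accepts the same code if you swap the import.

```python
import numpy as np


def alibi_slopes(num_heads):
    """Geometric ALiBi slopes m_h = 2^(-8 (h + 1) / H) for 0-based h = 0..H-1."""
    h = np.arange(num_heads)
    return 2.0 ** (-8.0 * (h + 1) / num_heads)


print(alibi_slopes(8))   # 0.5, 0.25, ..., 0.00390625
print(alibi_slopes(2))   # the H=2 slopes used in the worked example below
```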

The bias matrix

bias[h, i, j] = -m_h * |i - j|. Shape (H, T, T).

Concretely for H=2, T=4:

slopes = [0.0625, 0.00390625]
bias[0] = -0.0625 *
  [[0, 1, 2, 3],
   [1, 0, 1, 2],
   [2, 1, 0, 1],
   [3, 2, 1, 0]]
bias[1] = -0.00390625 * (same matrix)

How to inject into Flax MHA

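nn.MultiHeadDotProductAttention accepts a mask= argument at call time, but the mask is boolean: False positions get a large negative score (effectively -inf) and True positions are left unchanged. ALiBi needs an ADDITIVE, real-valued bias, which is a different code path.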

The underlying flax.linen.attention.dot_product_attention accepts a bias= arg directly. So we wrap it with a closure that pre-binds our bias, and pass that as attention_fn=:

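import flax.linen as nn
from flax.linen.attention import dot_product_attention

def make_alibi_fn(alibi_bias):
    # alibi_bias: the pre-computed (H, T, T) ALiBi bias, captured by the closure.
    def alibi_fn(query, key, value, mask=None, **kwargs):
        # Always add the ALiBi bias; forward mask and any other kwargs Flax passes.
        return dot_product_attention(query, key, value,
                                     bias=alibi_bias, mask=mask, **kwargs)
    return alibi_fn

attn = nn.MultiHeadDotProductAttention(num_heads=H,
                                       qkv_features=D,
                                       attention_fn=make_alibi_fn(alibi_bias))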

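attention_fn is a constructor argument of nn.MultiHeadDotProductAttention. Passing the closure replaces the default dot_product_attention with one that always adds your bias, while the mask and other keyword arguments Flax supplies are forwarded as-is.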

Bias shape gotcha

Internally, Flax computes scores with shape (batch..., H, T_q, T_k); for an unbatched 2-D input like x of shape (T, D_in) there is no batch dim, so the scores are (H, T_q, T_k). The bias only needs to broadcast against that. Our (H, T, T) works because it matches the trailing three dims, and any leading batch dims broadcast.
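A quick NumPy check of the broadcasting rule; all arrays are placeholders and only the shapes matter. It also shows that putting the head axis last fails to line up with the scores.

```python
import numpy as np

B, H, T = 3, 2, 4
scores_batched = np.zeros((B, H, T, T))   # (batch, H, T_q, T_k)
scores_unbatched = np.zeros((H, T, T))    # unbatched 2-D input: no batch dim
bias = np.ones((H, T, T))                 # stand-in for the ALiBi bias

print((scores_batched + bias).shape)      # (3, 2, 4, 4)
print((scores_unbatched + bias).shape)    # (2, 4, 4)

wrong = np.ones((T, T, H))                # heads in the last axis: misaligned
try:
    scores_batched + wrong
except ValueError as err:
    print("broadcast error:", err)
```

One caution: if H happened to equal T, a transposed bias would broadcast silently with the wrong meaning, so check the axis order rather than relying on an error.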

Computing |i - j|

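import jax.numpy as jnp

H, T = 2, 4                                          # matches the worked example
slopes = 2.0 ** (-8.0 * (jnp.arange(H) + 1) / H)     # (H,)
i = jnp.arange(T)[None, :, None]   # shape (1, T, 1)
j = jnp.arange(T)[None, None, :]   # shape (1, 1, T)
abs_diff = jnp.abs(i - j).astype(jnp.float32)   # (1, T, T)
bias = -slopes[:, None, None] * abs_diff        # (H, 1, 1) * (1, T, T) -> (H, T, T)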

Common pitfalls

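  • Using mask= instead of bias=: Flax masks are boolean, so they can only keep or drop a position; they cannot express ALiBi's graded, distance-proportional penalty. Injecting the bias through attention_fn with a closure is the right path.
  • Wrong slope formula: with 0-based h, 2 ** (-8 * (h + 1) / H) is the canonical one (first slope 2^(-8/H)). Some implementations drop the +1; with 0-based h that gives head 0 a slope of 1 and makes every slope 2^(8/H) times larger. Pick one convention and be consistent; this problem uses the +1 form.
  • Forgetting the negative sign: bias = -slopes * |i - j|. Without the negation, far tokens get a bonus instead of a penalty, so attention is pushed away from nearby tokens (anti-locality).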
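The first and third pitfalls are both visible on a single row of attention weights. This NumPy sketch uses query 0, keys at distances 0 to 4, slope 0.5, and a large negative constant standing in for the value Flax assigns to masked positions:

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())   # numerically stable softmax
    return e / e.sum()


dist = np.arange(5)            # |i - j| for query i = 0 and keys j = 0..4
slope = 0.5

flipped = softmax(+slope * dist)                   # sign bug: weight grows with distance
masked = softmax(np.where(dist <= 2, 0.0, -1e9))   # boolean-style mask: keep or drop only

print(np.round(flipped, 3))
print(np.round(masked, 3))
```

The flipped bias puts the most weight on the farthest key, and the boolean-style mask spreads weight evenly (1/3 each) over the kept keys with no preference for nearer ones.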

Problem

Implement mha_alibi(seed, x, num_heads, qkv_features):

  1. Compute slopes with H = num_heads: slopes = 2 ** (-8 * (jnp.arange(H) + 1) / H).
  2. Build the bias with T = x.shape[0]: shape (H, T, T) with bias[h, i, j] = -slopes[h] * |i - j|.
  3. Wrap dot_product_attention to inject the bias.
  4. Build MHA with attention_fn=alibi_fn. Init it from seed, apply it to x as self-attention, and return the output flattened.

Inputs:

  • seed: int.
  • x: 2-D (T, D_in).
  • num_heads, qkv_features: ints.

Output: 1-D, the flattened ALiBi-attention output.
