ALiBi: Attention with Linear Biases
Why this matters
Most position encodings inject position information into the model's representations: sinusoidal and learned encodings add it to the token embeddings, and RoPE rotates the queries and keys. ALiBi (Press et al., 2022) takes a different angle: leave the embeddings alone and bias the attention scores so that distant tokens are penalized.
Specifically, the score for query i and key j gets a bias -m_h * |i - j|,
where m_h is a per-head slope. The diagonal (i = j) is unaffected, nearby
tokens take a small penalty, and far tokens get progressively more
negative scores and therefore less attention.
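To make the effect on attention weights concrete, here is a minimal sketch for one head and one query position. The raw scores are set to zero (an assumption made only so that the bias is the sole influence), and the slope 0.5 and length 6 are illustrative values, not taken from any particular model.

```python
import jax
import jax.numpy as jnp

T = 6        # sequence length (illustrative)
m_h = 0.5    # slope for one head (illustrative)
i = T - 1    # query position: the last token

# Raw attention scores for query i; all zeros so only the bias matters.
scores = jnp.zeros(T)

# ALiBi bias for row i: -m_h * |i - j| for every key position j.
j = jnp.arange(T)
bias_row = -m_h * jnp.abs(i - j)

# After softmax, weights shrink as the key moves away from the query.
weights = jax.nn.softmax(scores + bias_row)
print(weights)
```

With real, non-zero scores the bias only shifts the logits, so a strongly matching far token can still win; the penalty is a preference, not a hard cutoff.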
Key practical advantage: ALiBi extrapolates. A model trained on a 2k context can run inference at 8k and still attend reasonably, because the bias for a given distance is the same at any sequence length. Learned absolute embeddings have no entries for unseen positions, and RoPE tends to degrade beyond its training length unless it is rescaled.
BLOOM (176B) used ALiBi. So did MPT-7B/30B.
The slopes
Per-head slope: m_h = 2^(-8 * (h + 1) / H) for h in [0, H).
For H=8: slopes ≈ [0.5, 0.25, 0.125, 0.0625, 0.0313, 0.0156, 0.0078, 0.0039].
Geometric series — early heads decay attention rapidly with distance,
later heads barely decay (so they’re effectively position-agnostic).
The mix gives the model a built-in inductive bias toward locality
while keeping some heads “global”.
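The slope formula can be checked numerically. This is a small sketch; the helper name alibi_slopes is a hypothetical choice for illustration, and it assumes the 0-indexed convention with the +1 stated above.

```python
import jax.numpy as jnp

def alibi_slopes(num_heads: int) -> jnp.ndarray:
    """Slopes m_h = 2^(-8 * (h + 1) / H) for h = 0..H-1."""
    h = jnp.arange(num_heads, dtype=jnp.float32)
    return 2.0 ** (-8.0 * (h + 1.0) / num_heads)

slopes = alibi_slopes(8)
print(slopes)

# Consecutive slopes share one ratio, 2^(-8 / H), which is what makes
# the sequence geometric.
print(slopes[1:] / slopes[:-1])
```

For H=8 the ratio is 0.5, so each head decays half as fast as the one before it.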
The bias matrix
bias[h, i, j] = -m_h * |i - j|. Shape (H, T, T).
Concretely for H=2, T=4:
slopes = [0.0625, 0.00390625]
bias[0] = -0.0625 *
[[0, 1, 2, 3],
[1, 0, 1, 2],
[2, 1, 0, 1],
[3, 2, 1, 0]]
bias[1] = -0.00390625 * (same matrix)
How to inject into Flax MHA
nn.MultiHeadDotProductAttention takes a mask= arg, but mask is
boolean (positions where it is False get a large negative score, so
they receive essentially zero attention). We need an ADDITIVE bias,
which is a different code path.
The underlying flax.linen.attention.dot_product_attention accepts a
bias= arg directly. So we wrap it with a closure that pre-binds our
bias, and pass that as attention_fn=:
import flax.linen as nn
from flax.linen.attention import dot_product_attention

def alibi_fn(query, key, value, bias=None, mask=None, **kwargs):
    # Any bias passed by the caller is ignored; the pre-computed
    # ALiBi bias (alibi_bias, built beforehand) is injected instead.
    return dot_product_attention(query, key, value,
                                 bias=alibi_bias, mask=mask, **kwargs)

attn = nn.MultiHeadDotProductAttention(num_heads=H,
                                       qkv_features=D,
                                       attention_fn=alibi_fn)
attention_fn is a constructor kwarg of nn.MultiHeadDotProductAttention
— replace the default SDPA with one that always adds your bias.
Bias shape gotcha
Internally Flax computes scores with shape (batch, H, T_q, T_k) (or
(H, T_q, T_k) if you pass a 2-D input — it’ll broadcast). The bias
needs to broadcast against that. Our (H, T, T) works because it
matches the trailing three dims; the missing leading batch dim broadcasts.
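The broadcasting rule can be checked without Flax. In this sketch the arrays are stand-ins (zeros and ones) for real scores and a real bias, and the sizes are illustrative.

```python
import jax.numpy as jnp

B, H, T = 3, 2, 4                   # illustrative sizes
scores = jnp.zeros((B, H, T, T))    # stand-in for batched attention scores
bias = jnp.ones((H, T, T))          # stand-in for an ALiBi bias

# (H, T, T) is treated as (1, H, T, T) and reused for every batch item.
print((scores + bias).shape)        # (3, 2, 4, 4)

# A bias whose trailing dims do not line up fails, e.g. (T, T, H).
try:
    scores + jnp.ones((T, T, H))
except (TypeError, ValueError) as err:
    print("incompatible shapes:", type(err).__name__)
```

If the head axis ends up last, as in the failing case, transpose the bias to (H, T, T) before passing it in.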
Computing |i - j|
i = jnp.arange(T)[None, :, None] # shape (1, T, 1)
j = jnp.arange(T)[None, None, :] # shape (1, 1, T)
abs_diff = jnp.abs(i - j).astype(jnp.float32) # (1, T, T)
bias = -slopes[:, None, None] * abs_diff # (H, T, T)
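The snippet above can be wrapped into a small function and checked against the H=2, T=4 example. The function name alibi_bias is a hypothetical choice for illustration. The last check shows why the bias extends to longer sequences: an entry depends only on |i - j|, never on T.

```python
import jax.numpy as jnp

def alibi_bias(num_heads: int, seq_len: int) -> jnp.ndarray:
    """Return the (H, T, T) bias with bias[h, i, j] = -m_h * |i - j|."""
    slopes = 2.0 ** (-8.0 * (jnp.arange(num_heads) + 1) / num_heads)
    pos = jnp.arange(seq_len)
    abs_diff = jnp.abs(pos[:, None] - pos[None, :]).astype(jnp.float32)  # (T, T)
    return -slopes[:, None, None] * abs_diff                             # (H, T, T)

bias = alibi_bias(num_heads=2, seq_len=4)
print(bias.shape)   # (2, 4, 4)
print(bias[0])      # -0.0625 times the distance matrix

# A longer bias contains the shorter one as its top-left block.
longer = alibi_bias(num_heads=2, seq_len=8)
print(bool(jnp.allclose(longer[:, :4, :4], bias)))  # True
```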
Common pitfalls
- Using mask= instead of bias=: Flax masks are boolean, so they can only keep or drop a position and would erase ALiBi's graded penalty. Injecting the bias through attention_fn with a closure is the right path.
- Wrong slope formula: 2 ** (-8 * (h+1) / H) is the canonical one. Some implementations use the head index h directly without the +1. Pick one and be consistent.
- Forgetting the negative sign: bias = -slopes * |i-j|. Without the negation, far tokens are favored over near ones (anti-locality).
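The sign and slope pitfalls can be caught with a few cheap assertions on the finished bias. This is a sketch: check_alibi_bias is a hypothetical helper, and the slope check assumes the +1 convention, under which no head has slope 1.

```python
import jax.numpy as jnp

def check_alibi_bias(bias: jnp.ndarray) -> None:
    """Raise AssertionError if an (H, T, T) bias breaks basic ALiBi properties."""
    # Sign: every entry must be <= 0, otherwise far tokens are favored.
    assert bool(jnp.all(bias <= 0)), "positive entries (missing minus sign?)"
    # A token's bias toward itself is zero.
    diag = jnp.diagonal(bias, axis1=-2, axis2=-1)
    assert bool(jnp.all(diag == 0)), "diagonal should be zero"
    # Slopes: recover m_h from the distance-1 entry; requires T >= 2.
    slopes = -bias[:, 0, 1]
    assert bool(jnp.all(slopes < 1.0)), "a head has slope >= 1 (missing +1?)"

H, T = 4, 5   # illustrative sizes
pos = jnp.arange(T)
dist = jnp.abs(pos[:, None] - pos[None, :])
slopes = 2.0 ** (-8.0 * (jnp.arange(H) + 1) / H)
good = -slopes[:, None, None] * dist

check_alibi_bias(good)   # passes silently
```

Calling check_alibi_bias(-good) trips the first assertion, and a bias built from 2 ** (-8 * h / H) trips the last one because head 0 then has slope 1.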
Problem
Implement mha_alibi(seed, x, num_heads, qkv_features):
- Compute slopes: slopes = 2 ** (-8 * (jnp.arange(H) + 1) / H).
- Build the bias: shape (H, T, T) with bias[h, i, j] = -slopes[h] * |i - j|.
- Wrap dot_product_attention to inject the bias.
- Build MHA with attention_fn=alibi_fn. Init, apply, return flat.
Inputs:
- seed: int.
- x: 2-D (T, D_in).
- num_heads, qkv_features: ints.
Output: 1-D, the flattened ALiBi-attention output.