ALiBi Bias Matrix
Why this matters
ALiBi (Attention with Linear Biases, Press et al. 2022) drops
position embeddings entirely. Instead, it adds a fixed, per-head linear
penalty to the attention scores, biasing each query against far-away
keys. Concretely, the score q · k between query i and key j is
nudged by -m_h · |i - j| for head h. The slope m_h is fixed (not
learned) and differs per head — some heads attend locally, some globally.
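Here is a minimal single-head sketch of that mechanism. It assumes the symmetric |i - j| penalty with no causal mask, and standard scaled dot-product scores. The function name alibi_attention_single_head is hypothetical and is not part of the problem below.

```python
import jax
import jax.numpy as jnp


def alibi_attention_single_head(q, k, v, m):
    """Single-head attention with an ALiBi distance penalty (no causal mask).

    q, k, v: arrays of shape (T, d); m: this head's fixed slope.
    Returns (output of shape (T, d), attention weights of shape (T, T)).
    """
    T, d = q.shape
    pos = jnp.arange(T)
    dist = jnp.abs(pos[:, None] - pos[None, :])   # |i - j|, shape (T, T)
    scores = q @ k.T / jnp.sqrt(d)                # scaled dot-product scores
    scores = scores - m * dist                    # ALiBi: subtract m * |i - j|
    weights = jax.nn.softmax(scores, axis=-1)     # normalise over keys j
    return weights @ v, weights


# With all-zero queries and keys the raw scores are equal, so any
# preference between keys comes purely from the distance penalty.
T, d = 5, 4
q = jnp.zeros((T, d))
k = jnp.zeros((T, d))
v = jnp.ones((T, d))
out, w = alibi_attention_single_head(q, k, v, m=0.5)
print(w[0])  # query 0: weights shrink by a factor exp(-0.5) per step away
```

Because the penalty is added to the logits, each extra step of distance multiplies a key's unnormalised weight by exp(-m).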
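Why does this work? The distance penalty acts as a soft locality (recency) prior. Nearby keys are favoured, but a strong enough q · k match can still pull attention to a distant key. Causality itself comes from the mask, not from the penalty. Combined with attention's existing flexibility, ALiBi matches sinusoidal and learned position embeddings on perplexity. It also extrapolates well to sequences longer than those seen in training. BLOOM and MPT both use ALiBi.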
The formula (causal variant)
For H heads and sequence length T, the bias tensor bias[h, i, j]
has shape (H, T, T):
bias[h, i, j] = -m_h * (i - j) if j <= i (allowed positions)
bias[h, i, j] = -inf if j > i (causal masked)
Note that i - j ≥ 0 for allowed positions, so the bias is always non-positive on the lower triangle. It is 0 on the diagonal and most negative for the oldest key (j = 0), where it equals -m_h · i.
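As a cross-check, here is a direct, loop-based transcription of the piecewise definition above. It is deliberately slow and only meant for small T. The name alibi_bias_loops and the mask_value parameter are illustrative and not part of the problem's API.

```python
import jax.numpy as jnp


def alibi_bias_loops(slopes, T, mask_value=float("-inf")):
    """Build bias[h, i, j] element by element from the causal ALiBi formula."""
    H = len(slopes)
    heads = []
    for h in range(H):
        head = []
        for i in range(T):
            row = []
            for j in range(T):
                if j <= i:
                    # allowed: -m_h * (i - j), written as m_h * (j - i)
                    # so the diagonal comes out as 0.0 rather than -0.0
                    row.append(slopes[h] * (j - i))
                else:
                    row.append(mask_value)  # future key: causally masked
            head.append(row)
        heads.append(head)
    return jnp.array(heads)  # shape (H, T, T)


b = alibi_bias_loops([0.5], T=3)
# b[0] holds:
#   [[ 0.0, -inf, -inf],
#    [-0.5,  0.0, -inf],
#    [-1.0, -0.5,  0.0]]
```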
Slopes
ALiBi uses a geometric sequence of slopes — head 0 the strongest, head
H-1 the weakest:
m_h = 2^(-8 * (h + 1) / H) for h ∈ [0, H)
Examples:
- H=2: m = [2^-4, 2^-8] = [1/16, 1/256].
- H=8: m = [2^-1, 2^-2, ..., 2^-8].
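The slope formula is a one-liner. This sketch, using a hypothetical helper alibi_slopes, reproduces both examples. Like the formula above, it assumes H is the number of heads; how non-power-of-two H is handled is not covered here.

```python
import jax.numpy as jnp


def alibi_slopes(H):
    """Geometric ALiBi slopes m_h = 2^(-8 * (h + 1) / H) for h = 0..H-1."""
    h = jnp.arange(H)
    return 2.0 ** (-8.0 * (h + 1) / H)


print(alibi_slopes(2))   # values 1/16 and 1/256
print(alibi_slopes(8))   # values 2^-1, 2^-2, ..., 2^-8
```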
A strong slope (m ≈ 1/2) subtracts 0.5 from the logit for every step back, so that head attends mostly to nearby tokens. A weak slope (m ≈ 1/256) barely changes the scores over short distances, so that head can look globally.
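To make the strong/weak contrast concrete, the sketch below takes the last query in a length-16 causal window and measures how much attention lands on the current token. It assumes all raw q · k scores are equal, so only the ALiBi bias matters. The setup and the helper name weight_on_current_token are illustrative.

```python
import jax
import jax.numpy as jnp


def weight_on_current_token(m, T=16):
    """Attention weight the last query puts on itself, given equal raw scores."""
    back = jnp.arange(T)[::-1]           # i - j for query i = T-1, keys j = 0..T-1
    weights = jax.nn.softmax(-m * back)  # equal raw scores: only the bias matters
    return float(weights[-1])            # key j = T-1 is the current token


for m in (1 / 2, 1 / 256):
    print(f"m = {m}: weight on current token = {weight_on_current_token(m):.3f}"
          f" (uniform would be 1/16)")
```

With m = 1/2 the current token alone takes roughly 39% of the attention. With m = 1/256 the weights stay close to uniform over the 16 keys.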
Use of -1e9 not -jnp.inf
For masked positions we use -1e9, not -jnp.inf. Two reasons:
- After softmax, a -1e9 score gets effectively zero weight (exp(-1e9 - max) underflows to 0), which has the same effect as -inf.
- jnp.inf does NOT survive JSON round-trips through this platform's expected-output serialiser (it gets dropped or NaN'd); -1e9 does.
Many production attention implementations likewise mask with a large finite negative value (often the most negative number the compute dtype can represent) rather than -inf. One motivation is that a row masked entirely with -inf makes softmax return NaN.
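A quick sketch of both behaviours, using jax.nn.softmax on toy rows chosen for illustration:

```python
import jax
import jax.numpy as jnp

# A causal row with one future key masked, using each choice of mask value.
row_big_neg = jnp.array([0.0, 0.0, -1e9])
row_neg_inf = jnp.array([0.0, 0.0, -jnp.inf])
print(jax.nn.softmax(row_big_neg))   # masked key gets (effectively) zero weight
print(jax.nn.softmax(row_neg_inf))   # same result with -inf

# A row where every key is masked. This cannot happen with a pure causal
# mask (the diagonal is always allowed), but it can with padding masks.
print(jax.nn.softmax(jnp.full(3, -1e9)))      # finite: falls back to uniform weights
print(jax.nn.softmax(jnp.full(3, -jnp.inf)))  # -inf - (-inf) = nan, so the row is NaN
```

One caveat: -1e9 is outside the float16 range (largest finite magnitude 65504), so it overflows to -inf in half precision. Code that runs in float16 typically uses something like jnp.finfo(jnp.float16).min instead.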
Vectorised build
import jax.numpy as jnp

H, T = 8, 16                                     # example sizes
h_idx = jnp.arange(H) + 1.0                      # 1..H
slopes = 2.0 ** (-8.0 * h_idx / H)               # (H,)
i = jnp.arange(T)[:, None]                       # (T, 1) query positions
j = jnp.arange(T)[None, :]                       # (1, T) key positions
distance = i - j                                 # (T, T), signed: >= 0 on and below the diagonal
bias = -slopes[:, None, None] * distance[None]   # (H, T, T)
bias = jnp.where(j > i, -1e9, bias)              # causal mask
The broadcast -slopes[:, None, None] * distance[None] does the per-head
expansion in one shot.
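For context, here is a sketch of how the (H, T, T) tensor is typically consumed: it is added to the per-head scores before the softmax. The function names, the einsum layout (heads, then positions, then features), and the 1/sqrt(d) scaling are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def alibi_causal_bias(H, T):
    """Vectorised (H, T, T) causal ALiBi bias, as in the build above."""
    slopes = 2.0 ** (-8.0 * (jnp.arange(H) + 1.0) / H)   # (H,)
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    bias = -slopes[:, None, None] * (i - j)[None]        # (H, T, T)
    return jnp.where(j > i, -1e9, bias)                  # causal mask


def alibi_attention(q, k, v):
    """Multi-head causal attention with ALiBi. q, k, v: (H, T, d)."""
    H, T, d = q.shape
    scores = jnp.einsum("hid,hjd->hij", q, k) / jnp.sqrt(d)   # (H, T, T)
    weights = jax.nn.softmax(scores + alibi_causal_bias(H, T), axis=-1)
    return jnp.einsum("hij,hjd->hid", weights, v), weights


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (4, 6, 8))
k = jax.random.normal(key_k, (4, 6, 8))
v = jax.random.normal(key_v, (4, 6, 8))
out, w = alibi_attention(q, k, v)
print(out.shape, w.shape)   # (4, 6, 8) (4, 6, 6)
```

No learned position parameters appear anywhere. All positional information enters through the additive bias.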
Common pitfalls
- Slope formula off-by-one: it's 2^(-8 * (h+1) / H), not 2^(-8 * h / H). Head 0 gets 2^(-8/H), not 1.
- Sign of distance: bias is -m_h * (i - j). For j > i (future tokens), i - j < 0, so -m_h * (i - j) > 0, which would encourage attending to the future. The causal mask overwrites those entries with -1e9, so it doesn't matter, but get the sign right anyway.
- jnp.inf instead of -1e9: see above; it breaks the test harness.
- Forgetting per-head broadcast: bias is (H, T, T), not (T, T). Each head gets its own slope.
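Each pitfall above can be caught with a cheap assertion. Below is a sketch of a hypothetical checker, check_alibi_bias, assuming the -1e9 mask convention:

```python
import jax.numpy as jnp


def check_alibi_bias(bias, H, T, mask_value=-1e9):
    """Hypothetical sanity checks targeting the common pitfalls."""
    # Per-head broadcast: one (T, T) slice per head.
    assert bias.shape == (H, T, T), f"expected {(H, T, T)}, got {bias.shape}"
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    future = jnp.broadcast_to(j > i, bias.shape)
    # Causal mask: every future entry holds the finite mask value.
    assert bool(jnp.all(jnp.where(future, bias == mask_value, True)))
    # Sign: allowed entries are <= 0, with 0 on the diagonal.
    assert bool(jnp.all(jnp.where(future, True, bias <= 0)))
    assert bool(jnp.all(jnp.diagonal(bias, axis1=1, axis2=2) == 0))
    if T > 1:
        # Off-by-one: one step back on head 0 costs 2^(-8/H), not 1.
        assert bool(jnp.isclose(bias[0, 1, 0], -(2.0 ** (-8.0 / H))))
    return True


H, T = 4, 5
slopes = 2.0 ** (-8.0 * (jnp.arange(H) + 1.0) / H)
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
good = jnp.where(j > i, -1e9, -slopes[:, None, None] * (i - j)[None])
print(check_alibi_bias(good, H, T))   # True
```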
Problem
Implement alibi_bias_matrix(seed, num_heads, T):
- Cast num_heads and T to int. (seed is unused.)
- Compute slopes m_h = 2^(-8 * (h + 1) / H) for h in [0, H).
- Build the (H, T, T) bias tensor with -m_h * (i - j) for j ≤ i and -1e9 for j > i.
- Return it flattened.

Inputs:
- seed: int (unused).
- num_heads: int H.
- T: int, the sequence length.

Output: 1-D array of length H · T · T.
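Here is one possible sketch of the requested function. It assumes row-major (C-order) flattening, which is what jnp.ravel produces. The problem statement does not spell out the order, so treat that as an assumption.

```python
import jax.numpy as jnp


def alibi_bias_matrix(seed, num_heads, T):
    """Causal ALiBi bias of shape (H, T, T), returned flattened (row-major)."""
    del seed                                               # unused, per the problem statement
    H, T = int(num_heads), int(T)
    slopes = 2.0 ** (-8.0 * (jnp.arange(H) + 1.0) / H)     # (H,)
    i = jnp.arange(T)[:, None]                             # query positions
    j = jnp.arange(T)[None, :]                             # key positions
    bias = -slopes[:, None, None] * (i - j)[None]          # (H, T, T)
    bias = jnp.where(j > i, -1e9, bias)                    # causal mask
    return jnp.ravel(bias)                                 # length H * T * T


flat = alibi_bias_matrix(0, 2, 3)
# Row-major layout: bias[h, i, j] lives at flat[h * T * T + i * T + j].
print(flat.shape)   # (18,)
```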