hard primitives

ALiBi Bias Matrix

Why this matters

ALiBi (Attention with Linear Biases, Press et al. 2022) drops position embeddings entirely. Instead, it adds a fixed, per-head linear penalty to the attention scores, biasing each query against far-away keys. Concretely, the score q · k between query i and key j is nudged by -m_h · |i - j| for head h. The slope m_h is fixed (not learned) and differs per head — some heads attend locally, some globally.
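
As a rough illustration of where the bias enters, the sketch below adds a causal ALiBi bias to a single head's scores before the softmax. The helper name alibi_attention_weights and the all-zero scores are assumptions made for the example, not part of the problem.

```python
import jax
import jax.numpy as jnp

def alibi_attention_weights(scores, slope):
    """Add a causal ALiBi bias to (T, T) scores for one head, then softmax over keys.

    `scores` stands in for the q . k products; the function name is illustrative.
    """
    T = scores.shape[0]
    i = jnp.arange(T)[:, None]             # query positions, (T, 1)
    j = jnp.arange(T)[None, :]             # key positions,   (1, T)
    bias = -slope * jnp.abs(i - j)         # linear distance penalty
    bias = jnp.where(j > i, -1e9, bias)    # causal mask: hide future keys
    return jax.nn.softmax(scores + bias, axis=-1)

# With all-zero scores, the attention weights depend on the bias alone.
scores = jnp.zeros((4, 4))
local_head = alibi_attention_weights(scores, slope=0.5)       # strong slope
global_head = alibi_attention_weights(scores, slope=1 / 256)  # weak slope
```

In the last row, the strong-slope head puts noticeably more weight on the current token than the weak-slope head, whose weights stay close to uniform.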

Why does this work? The distance penalty acts as a soft locality (recency) prior: nearby keys are penalised less than distant ones. Combined with attention’s existing flexibility, ALiBi matches sinusoidal or learned position embeddings on perplexity and extrapolates well to sequences longer than those seen in training. BLOOM and MPT use ALiBi.

The formula (causal variant)

For H heads and sequence length T, the bias tensor bias[h, i, j] has shape (H, T, T):

bias[h, i, j] = -m_h * (i - j)    if j <= i  (allowed positions)
bias[h, i, j] = -inf              if j > i   (causal masked)

Note i - j ≥ 0 for allowed positions, so the bias is always non-positive on the lower triangle; it’s 0 on the diagonal and most negative far back.
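
To make the formula concrete, here is a deliberately slow reference build that fills the tensor one entry at a time. It is only a sketch: the name alibi_bias_loop is hypothetical, and it uses -inf for masked entries exactly as the formula above is written.

```python
import jax.numpy as jnp

def alibi_bias_loop(slopes, T):
    """Reference build of the (H, T, T) causal ALiBi bias, one entry at a time."""
    rows = [[[-m * (i - j) if j <= i else float("-inf")
              for j in range(T)]      # key index
             for i in range(T)]       # query index
            for m in slopes]          # one (T, T) slab per head
    return jnp.array(rows)

bias = alibi_bias_loop([0.5], T=3)
# bias[0] is (the diagonal may print as -0., which equals 0):
# [[ 0.0, -inf, -inf],
#  [-0.5,  0.0, -inf],
#  [-1.0, -0.5,  0.0]]
```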

Slopes

ALiBi uses a geometric sequence of slopes — head 0 the strongest, head H-1 the weakest:

m_h = 2^(-8 * (h + 1) / H)        for h ∈ [0, H)

Examples:

  • H=2: m = [2^-4, 2^-8] = [1/16, 1/256].
  • H=8: m = [2^-1, 2^-2, ..., 2^-8].

Strong slopes (m ≈ 1/2) make a head attend extremely locally; weak slopes (m ≈ 1/256) barely affect attention, so that head looks globally.
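
The slope formula is short enough to check directly. A minimal sketch, with alibi_slopes as an illustrative helper name:

```python
import jax.numpy as jnp

def alibi_slopes(H):
    """Slopes m_h = 2^(-8 * (h + 1) / H) for h = 0..H-1, shape (H,)."""
    h = jnp.arange(H) + 1.0          # 1..H, so head 0 never gets slope 1
    return 2.0 ** (-8.0 * h / H)

slopes_2 = alibi_slopes(2)   # approximately [1/16, 1/256]
slopes_8 = alibi_slopes(8)   # approximately [1/2, 1/4, ..., 1/256]
```

Consecutive slopes differ by the constant factor 2^(-8/H), which is what makes the sequence geometric.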

Use of -1e9 not -jnp.inf

For masked positions we use -1e9, not -jnp.inf. Two reasons:

  1. After softmax, a -1e9 entry receives weight 0 (exp(-1e9) underflows to 0), the same effect as -inf.
  2. jnp.inf does NOT survive JSON round-trips through this platform’s expected-output serialiser (it gets dropped or NaN’d). -1e9 does.

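Masking with a large finite negative number instead of -inf is also common in attention implementations generally, though usually for numerical robustness (a row that is entirely -inf produces NaN after softmax) rather than for serialisation.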
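
Both reasons can be checked in a few lines. In this sketch, Python's standard json module with allow_nan=False stands in for a strict JSON serialiser; the platform's actual serialiser is not shown in this text, so treat that part as an analogy only.

```python
import json
import jax
import jax.numpy as jnp

scores = jnp.array([0.3, -1.2, 2.0, 0.7])
masked = jnp.array([False, False, True, True])   # pretend the last two keys are in the future

w_finite = jax.nn.softmax(jnp.where(masked, -1e9, scores))
w_inf = jax.nn.softmax(jnp.where(masked, -jnp.inf, scores))
# Both variants give zero weight to the masked keys.

# Strict JSON has no infinity literal: json.dumps refuses -inf when
# allow_nan=False, while -1e9 serialises as an ordinary number.
finite_json = json.dumps(-1e9, allow_nan=False)
try:
    json.dumps(float("-inf"), allow_nan=False)
    inf_is_serialisable = True
except ValueError:
    inf_is_serialisable = False
```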

Vectorised build

import jax.numpy as jnp

h_idx  = jnp.arange(H) + 1.0                      # 1..H
slopes = 2.0 ** (-8.0 * h_idx / H)                # (H,)
i = jnp.arange(T)[:, None]                        # (T, 1) query index
j = jnp.arange(T)[None, :]                        # (1, T) key index
distance = i - j                                  # (T, T) signed: >= 0 on and below the diagonal
bias = -slopes[:, None, None] * distance[None]    # (H, T, T)
bias = jnp.where(j > i, -1e9, bias)               # causal mask: overwrite future keys

The broadcast -slopes[:, None, None] * distance[None] does the per-head expansion in one shot.

Common pitfalls

  • Slope formula off-by-one: it’s 2^(-8 * (h+1) / H), not 2^(-8 * h / H). Head 0 gets 2^(-8/H), not 1.
  • Sign of distance: bias is -m_h * (i - j). For j > i (future tokens) i - j < 0, so -m_h * (i - j) > 0 — that would encourage attending to the future. The causal mask overwrites those entries with -1e9, so it doesn’t matter, but get the sign right anyway.
  • -jnp.inf instead of -1e9: see above; it breaks the test harness.
  • Forgetting per-head broadcast: bias is (H, T, T), not (T, T). Each head gets its own slope.

Problem

Implement alibi_bias_matrix(seed, num_heads, T):

  1. Cast num_heads, T to int. (seed is unused.)
  2. Compute slopes m_h = 2^(-8 * (h + 1) / H) for h in [0, H).
  3. Build the (H, T, T) bias tensor with -m_h * (i - j) for j ≤ i and -1e9 for j > i.
  4. Return the (H, T, T) tensor flattened to 1-D.

Inputs:

  • seed: int (unused).
  • num_heads: int H.
  • T: int — sequence length.

Output: 1-D array of length H · T · T.
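
Putting the four steps together, one possible sketch looks like this. It is not the platform's reference solution, and the row-major flattening order (index h · T · T + i · T + j) is an assumption, since the statement only says "flattened".

```python
import jax.numpy as jnp

def alibi_bias_matrix(seed, num_heads, T):
    """Flattened causal ALiBi bias of length H * T * T (seed is ignored)."""
    H, T = int(num_heads), int(T)
    slopes = 2.0 ** (-8.0 * (jnp.arange(H) + 1.0) / H)   # (H,)
    i = jnp.arange(T)[:, None]                            # (T, 1) query index
    j = jnp.arange(T)[None, :]                            # (1, T) key index
    bias = -slopes[:, None, None] * (i - j)[None]         # (H, T, T)
    bias = jnp.where(j > i, -1e9, bias)                   # mask future keys
    return bias.reshape(-1)                               # assumed row-major order

flat = alibi_bias_matrix(seed=0, num_heads=2, T=3)
# Head 0 (slope 1/16) occupies flat[:9]; head 1 (slope 1/256) occupies flat[9:].
```

For example, flat[3] is head 0's entry for query 1 and key 0, which is about -1/16, and flat[1] is the masked entry for query 0 and key 1.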
