
T5 Relative Position Bucketing

Why this matters

T5 (Raffel et al. 2020) introduced a compact, learned relative position bias. Each (query_pos, key_pos) pair maps to one of num_buckets bucket indices, and the model learns one scalar per bucket per head, which is added to the attention logits. The bucketing scheme is the clever part: half the buckets are EXACT for small distances, and the rest are LOGARITHMICALLY spaced up to max_distance. The model therefore resolves nearby tokens precisely and far-away tokens coarsely.
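The lookup from bucket indices to per-head scalars can be sketched as follows. This is a minimal illustration, assuming a bias table of shape (num_buckets, num_heads) indexed by bucket. The names bias_table and num_heads, the random initialization, and the zero placeholder logits are illustrative assumptions, not T5's actual module. The bucket matrix is copied from the worked example below (T=4, num_buckets=8, max_distance=16).

```python
import jax
import jax.numpy as jnp

# Bucket ids for T=4, num_buckets=8, max_distance=16 (worked example result).
buckets = jnp.array([[0, 5, 6, 6],
                     [1, 0, 5, 6],
                     [2, 1, 0, 5],
                     [2, 2, 1, 0]], dtype=jnp.int32)

num_buckets, num_heads, T = 8, 2, 4
# Hypothetical learned parameter: one scalar per (bucket, head).
bias_table = jax.random.normal(jax.random.PRNGKey(0), (num_buckets, num_heads))

# Gather (T, T) bucket ids -> (T, T, H) scalars, then put heads first.
bias = jnp.transpose(bias_table[buckets], (2, 0, 1))  # shape (H, T, T)

# The bias is added to the attention logits before the softmax.
logits = jnp.zeros((num_heads, T, T))  # placeholder for q @ k^T scores
weights = jax.nn.softmax(logits + bias, axis=-1)
```

In row 0, offsets +2 and +3 both fall in bucket 6, so every head gives them the same bias. This sharing is the coarse treatment of far-away tokens.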

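The headline benefit: a model trained at one sequence length can be applied at other lengths at eval time. This works because the bucket function is a fixed function of the offset (not a learned per-position vector), and every offset beyond max_distance falls into the last bucket of its direction.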

How the buckets are laid out (bidirectional self-attention)

With num_buckets=N (even):

  1. Half the buckets are for non-positive offsets (key_pos ≤ query_pos), and the other half are for positive offsets (key_pos > query_pos). The sign is encoded by adding N/2 when relative_position > 0.
  2. Within each half (size N/2):
    • The first N/4 buckets are EXACT for distances [0, N/4).
    • The remaining N/4 buckets are LOGARITHMICALLY spaced from N/4 to max_distance.

So with N=32, max_distance=128: for non-positive offsets (including 0), buckets 0–7 are exact and 8–15 are log-spaced. For positive offsets, buckets 16–23 are exact and 24–31 are log-spaced. Bucket 16 (positive sign, distance 0) is never actually produced.
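To check this layout, you can bucket a range of offsets and inspect where they land. The sketch below re-implements the bucket function from the next section in a slightly modified form. The log argument is clamped to at least 1, which changes no output because distance 0 always takes the exact branch. The variable names are illustrative.

```python
import jax.numpy as jnp

def t5_bucket(rel, num_buckets=32, max_distance=128):
    rel = jnp.asarray(rel, dtype=jnp.int32)
    nb_half = num_buckets // 2
    max_exact = nb_half // 2
    out = (rel > 0).astype(jnp.int32) * nb_half  # sign -> upper half
    dist = jnp.abs(rel)
    # Log argument clamped to >= 1; distance 0 uses the exact branch anyway.
    large = max_exact + (
        jnp.log(jnp.maximum(dist, 1).astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (nb_half - max_exact)
    ).astype(jnp.int32)
    large = jnp.minimum(large, nb_half - 1)
    return out + jnp.where(dist < max_exact, dist, large)

offsets = jnp.arange(-300, 301)
b = t5_bucket(offsets)  # N=32, max_distance=128

# Negative half: smallest distance that falls into each log-spaced bucket.
dist = jnp.arange(0, 301)
b_neg = t5_bucket(-dist)
for k in range(8, 16):
    print(f"bucket {k}: starts at distance {int(dist[b_neg == k][0])}")
```

Every distance of 128 or more, however large, lands in bucket 15 (non-positive side) or 31 (positive side). This is why the same function serves sequences longer than those seen in training.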

The algorithm (T5 / Hugging Face reference)

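import jax.numpy as jnp

def t5_bucket(relative_position, num_buckets=32, max_distance=128):
    relative_position = jnp.asarray(relative_position, dtype=jnp.int32)
    relative_buckets = 0
    nb_half = num_buckets // 2
    # SIGN: positive offsets (key after query) go to the upper half.
    relative_buckets += (relative_position > 0).astype(jnp.int32) * nb_half
    relative_position = jnp.abs(relative_position)
    # The first nb_half // 2 buckets of each half hold exact distances.
    max_exact = nb_half // 2
    is_small = relative_position < max_exact
    # Larger distances: log-spaced buckets from max_exact up to max_distance.
    relative_position_if_large = max_exact + (
        jnp.log(relative_position.astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (nb_half - max_exact)
    ).astype(jnp.int32)
    # Clamp so distances >= max_distance stay in the last bucket (nb_half - 1).
    relative_position_if_large = jnp.minimum(
        relative_position_if_large, nb_half - 1
    )
    relative_buckets += jnp.where(
        is_small, relative_position, relative_position_if_large
    )
    return relative_buckets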

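Read it slowly. The first add encodes the SIGN. The code then takes the absolute distance and either keeps it (if small) or log-buckets it. The log scaling maps log(distance) over [max_exact, max_distance] linearly onto [max_exact, nb_half]. Truncation to int and the clamp then keep the index in [max_exact, nb_half - 1].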
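Written as a single formula, with r = relative_position and N = num_buckets, the function is shown below. The symbols h, e, and D are shorthand introduced here for nb_half, max_exact, and max_distance. For d ≥ e the argument of the floor is non-negative, so the code's int cast (truncation) coincides with the floor.

```latex
b(r) = h\,[r > 0] +
\begin{cases}
  d, & d < e, \\
  \min\!\left(e + \left\lfloor \dfrac{\ln(d/e)}{\ln(D/e)}\,(h - e) \right\rfloor,\; h - 1\right), & d \ge e,
\end{cases}
\qquad d = |r|,\quad h = \lfloor N/2 \rfloor,\quad e = \lfloor h/2 \rfloor .
```

With N = 8, D = 16, and r = 3, this gives h = 4, e = 2, and b = 4 + min(2 + floor(ln 1.5 / ln 8 · 2), 3) = 4 + 2 = 6.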

The bucket function is applied elementwise to the (T, T) matrix relative_position[i, j] = j - i, where i is the query position and j is the key position.
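With broadcasting, the whole offset matrix comes from one subtraction of two position vectors. In this small sketch, relative_positions is a hypothetical helper name.

```python
import jax.numpy as jnp

def relative_positions(T):
    """Hypothetical helper: (T, T) int32 matrix with entry [i, j] = j - i."""
    pos = jnp.arange(T, dtype=jnp.int32)
    # Row vector of key positions minus column vector of query positions.
    return pos[None, :] - pos[:, None]

rel = relative_positions(4)
print(rel.tolist())
# [[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1], [-3, -2, -1, 0]]
```

Every step of the bucket function is an elementwise jnp operation. Passing this whole matrix in therefore buckets all T · T entries at once, with no explicit loop or vmap.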

Worked example

T=4, num_buckets=8, max_distance=16. Then nb_half = 4, max_exact = 2.

Distances form the matrix j - i:

[ 0  1  2  3]
[-1  0  1  2]
[-2 -1  0  1]
[-3 -2 -1  0]

For each cell:

  • j > i (positive): add 4 (the half-boost).
  • |d| = 0 or 1 → small (below max_exact = 2), bucket = |d|.
  • |d| ≥ 2 → log-bucketed: bucket = 2 + floor(log(|d| / 2) / log(16 / 2) * 2), clamped to at most nb_half - 1 = 3. So |d| = 2 gives 2 + floor(0) = 2, and |d| = 3 gives 2 + floor(0.39) = 2.

Result: [[0, 5, 6, 6], [1, 0, 5, 6], [2, 1, 0, 5], [2, 2, 1, 0]].
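You can reproduce the worked example end to end. This sketch uses a local re-definition of t5_bucket with the log argument clamped to at least 1. That clamp does not change any output, because distances 0 and 1 take the exact branch here.

```python
import jax.numpy as jnp

def t5_bucket(rel, num_buckets, max_distance):
    rel = jnp.asarray(rel, dtype=jnp.int32)
    nb_half = num_buckets // 2
    max_exact = nb_half // 2
    out = (rel > 0).astype(jnp.int32) * nb_half  # sign -> upper half
    dist = jnp.abs(rel)
    large = max_exact + (
        jnp.log(jnp.maximum(dist, 1).astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (nb_half - max_exact)
    ).astype(jnp.int32)
    large = jnp.minimum(large, nb_half - 1)
    return out + jnp.where(dist < max_exact, dist, large)

pos = jnp.arange(4)
rel = pos[None, :] - pos[:, None]  # rel[i, j] = j - i
buckets = t5_bucket(rel, num_buckets=8, max_distance=16)
print(buckets.tolist())
# [[0, 5, 6, 6], [1, 0, 5, 6], [2, 1, 0, 5], [2, 2, 1, 0]]
```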

Common pitfalls

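  • jnp.log(0) at relative_position = 0 is -inf. jnp.where computes both branches, so relative_position_if_large holds a meaningless value in those cells. is_small is True there, so where selects the exact branch and the output is correct. Any warning or debug-flag error you see about it (for example with jax_debug_infs enabled) comes from the discarded branch. Computing the log on jnp.maximum(relative_position, 1) sidesteps the issue without changing any result.
  • int vs int32: keep distances and bucket indices as int32. The only float work is the temporary float32 log inside the large-distance branch. The cast of the final bucket indices to float32 happens after bucketing.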
  • max_exact = num_buckets // 4, not num_buckets // 2. It’s a quarter of the total, half of one direction.
  • Off-by-one on the clamp: clamp to nb_half - 1, not nb_half, so the bucket index stays in [0, nb_half).
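You can check the log(0) pitfall directly. The sketch below adds a hypothetical clamp_log switch to the bucket function. This lets you compare the reference behavior (log(0) computed, then discarded) with the clamped version over a wide range of offsets.

```python
import jax.numpy as jnp

def t5_bucket(rel, num_buckets=32, max_distance=128, clamp_log=False):
    rel = jnp.asarray(rel, dtype=jnp.int32)
    nb_half = num_buckets // 2
    max_exact = nb_half // 2  # num_buckets // 4, not num_buckets // 2
    out = (rel > 0).astype(jnp.int32) * nb_half
    dist = jnp.abs(rel)
    # Reference behavior feeds dist == 0 into log (-> -inf); clamp_log avoids it.
    log_arg = jnp.maximum(dist, 1) if clamp_log else dist
    large = max_exact + (
        jnp.log(log_arg.astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (nb_half - max_exact)
    ).astype(jnp.int32)
    large = jnp.minimum(large, nb_half - 1)  # nb_half - 1, not nb_half
    return out + jnp.where(dist < max_exact, dist, large)

print(float(jnp.log(jnp.float32(0.0))))  # -inf: what the discarded branch computes

rel = jnp.arange(-500, 501)
ref = t5_bucket(rel)                   # log(0) computed, then discarded by where
safe = t5_bucket(rel, clamp_log=True)  # log argument never 0
print(bool(jnp.all(ref == safe)))      # True: where() never selects those cells
print(int(ref.min()), int(ref.max()))  # indices stay in [0, num_buckets)
```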

Problem

Implement t5_relative_pos(seed, T, num_buckets, max_distance):

  1. Build the (T, T) integer matrix relative_position[i, j] = j - i.
  2. Apply the T5 bucketing function above to each entry.
  3. Cast to float32 and return flattened.

seed is unused; kept for signature consistency.

Inputs:

  • seed: int (unused).
  • T: int, the sequence length.
  • num_buckets: int N, even (typically 32).
  • max_distance: int β€” typically 128.

Output: 1-D float array of length T · T.
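One possible solution sketch follows. It assumes jax.numpy and is not the site's reference solution. It builds the offsets by broadcasting and inlines the bucketing with the log argument clamped to at least 1, so no -inf is ever produced. It then flattens in row-major order (row i = query position).

```python
import jax.numpy as jnp

def t5_relative_pos(seed, T, num_buckets, max_distance):
    del seed  # unused; kept for signature consistency
    pos = jnp.arange(T, dtype=jnp.int32)
    rel = pos[None, :] - pos[:, None]  # rel[i, j] = j - i
    nb_half = num_buckets // 2
    max_exact = nb_half // 2
    buckets = (rel > 0).astype(jnp.int32) * nb_half  # sign -> upper half
    dist = jnp.abs(rel)
    # Clamp the log argument to >= 1: distance-0 cells take the exact branch anyway.
    large = max_exact + (
        jnp.log(jnp.maximum(dist, 1).astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (nb_half - max_exact)
    ).astype(jnp.int32)
    large = jnp.minimum(large, nb_half - 1)
    buckets = buckets + jnp.where(dist < max_exact, dist, large)
    return buckets.astype(jnp.float32).reshape(-1)
```

For example, t5_relative_pos(0, 4, 8, 16) returns the worked example's matrix flattened as float32: [0, 5, 6, 6, 1, 0, 5, 6, 2, 1, 0, 5, 2, 2, 1, 0].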
