T5 Relative Position Bucketing
Why this matters
T5 (Raffel et al. 2020) introduced a compact, learned relative position
bias: each (query_pos, key_pos) pair maps to one of num_buckets
bucket indices, and the model learns one scalar per bucket per head. The
bucketing scheme is the clever part: half the buckets are EXACT for
small distances, the rest are LOGARITHMICALLY spaced up to
max_distance. So the model treats nearby tokens precisely and far-away
tokens coarsely.
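The headline benefit: a model trained at one sequence length can be run at other lengths at eval time, because the bias depends only on the relative offset j - i through a deterministic bucket function, not on a learned per-position vector. Offsets beyond max_distance all share the last bucket of their direction.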
How the buckets are laid out (bidirectional self-attention)
With num_buckets=N (even):
- Half the buckets are for negative offsets (key_pos < query_pos), the other half for positive offsets. The sign is encoded by a +N/2 boost when relative_position > 0.
- Within each half (size N/2):
  - The first N/4 buckets are EXACT for distances [0, N/4).
  - The remaining N/4 buckets are LOGARITHMICALLY spaced from N/4 to max_distance.
So with N=32, max_distance=128: buckets 0–7 are exact and 8–15 are
log-spaced for negative offsets; 16–23 are exact and 24–31 are
log-spaced for positive offsets.
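To make the layout concrete, here is a minimal pure-Python sketch that groups negative offsets by bucket for N=32, max_distance=128. The name t5_bucket_scalar is hypothetical and the code is a scalar restatement for illustration only. It uses float64 math, so a distance that sits exactly on a log boundary may land one bucket away from what a float32 implementation gives.

```python
import math
from collections import defaultdict


def t5_bucket_scalar(d, num_buckets=32, max_distance=128):
    """Scalar restatement of the bidirectional bucketing rule (illustrative)."""
    nb_half = num_buckets // 2          # buckets per direction
    max_exact = nb_half // 2            # exact buckets per direction
    offset = nb_half if d > 0 else 0    # positive offsets use the upper half
    dist = abs(d)
    if dist < max_exact:
        return offset + dist            # one bucket per small distance
    # Position of log(dist) between log(max_exact) and log(max_distance).
    scaled = math.log(dist / max_exact) / math.log(max_distance / max_exact)
    log_bucket = max_exact + int(scaled * (nb_half - max_exact))
    return offset + min(log_bucket, nb_half - 1)   # clamp far distances


# Group negative offsets by bucket to see which distances share a bucket.
ranges = defaultdict(list)
for dist in range(0, 200):
    ranges[t5_bucket_scalar(-dist)].append(dist)

for bucket in sorted(ranges):
    members = ranges[bucket]
    print(f"bucket {bucket:2d}: distances {members[0]}..{members[-1]}")
```

Buckets 0 to 7 each hold a single distance, while the log-spaced buckets cover progressively wider ranges and the last one absorbs everything at or beyond max_distance.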
The algorithm (T5 / Hugging Face reference)
import jax.numpy as jnp

def t5_bucket(relative_position, num_buckets=32, max_distance=128):
    relative_buckets = 0
    nb_half = num_buckets // 2
    # Positive offsets (key after query) use the upper half of the buckets.
    relative_buckets += (relative_position > 0).astype(jnp.int32) * nb_half
    relative_position = jnp.abs(relative_position)
    max_exact = nb_half // 2
    is_small = relative_position < max_exact
    # Log-spaced bucket index for larger distances.
    relative_position_if_large = max_exact + (
        jnp.log(relative_position.astype(jnp.float32) / max_exact)
        / jnp.log(max_distance / max_exact)
        * (nb_half - max_exact)
    ).astype(jnp.int32)
    # Distances at or beyond max_distance share the last bucket.
    relative_position_if_large = jnp.minimum(
        relative_position_if_large, nb_half - 1
    )
    relative_buckets += jnp.where(
        is_small, relative_position, relative_position_if_large
    )
    return relative_buckets
Read it slowly. The first add encodes the SIGN. Then we take the absolute
distance and either keep it (if small) or log-bucket it. The log scaling
maps log(distance) linearly, so that [max_exact, max_distance] lands on
[max_exact, nb_half]; the final minimum then clamps the result to
nb_half - 1.
The bucket function is applied independently to each (i, j) pair in
the (T, T) matrix where relative_position[i, j] = j - i.
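The same rule can be written as a single formula. This is a restatement of the code above, assuming N = num_buckets is divisible by 4 so that the integer divisions are exact, with d = j - i and [d > 0] equal to 1 when d is positive and 0 otherwise. The floor stands in for the int32 cast, which is equivalent here because the argument is non-negative in the log region.

```latex
\[
\operatorname{bucket}(d) = \frac{N}{2}\,[d > 0] +
\begin{cases}
|d| & \text{if } |d| < \frac{N}{4},\\[6pt]
\min\!\left(\frac{N}{2} - 1,\;
  \frac{N}{4} + \left\lfloor
    \dfrac{\ln\!\bigl(|d| \,/\, \tfrac{N}{4}\bigr)}
          {\ln\!\bigl(\mathrm{max\_distance} \,/\, \tfrac{N}{4}\bigr)}
    \cdot \frac{N}{4}
  \right\rfloor\right) & \text{otherwise.}
\end{cases}
\]
```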
Worked example
T=4, num_buckets=8, max_distance=16. Then nb_half = 4, max_exact = 2.
Distances form the matrix j - i:
[ 0 1 2 3]
[-1 0 1 2]
[-2 -1 0 1]
[-3 -2 -1 0]
For each cell:
- j > i (positive): add 4 (the half-boost).
- |d| = 0 or 1: small, bucket = abs(d).
- |d| ≥ 2: log-bucketed. With max_exact=2 and nb_half-max_exact=2, the log map sends 2 → 2, 3 → 2 (rounded down).
Result: [[0, 5, 6, 6], [1, 0, 5, 6], [2, 1, 0, 5], [2, 2, 1, 0]].
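The worked example can be checked by hand with a short pure-Python sketch. The helper name bucket is hypothetical and this is a scalar restatement of the rule, not the JAX reference. It should reproduce the result matrix shown above.

```python
import math


def bucket(d, num_buckets, max_distance):
    # Scalar version of the bucketing rule (illustrative, float64 math).
    nb_half = num_buckets // 2
    max_exact = nb_half // 2
    offset = nb_half if d > 0 else 0          # half-boost for j > i
    dist = abs(d)
    if dist < max_exact:
        return offset + dist                  # exact region
    scaled = math.log(dist / max_exact) / math.log(max_distance / max_exact)
    log_bucket = max_exact + int(scaled * (nb_half - max_exact))
    return offset + min(log_bucket, nb_half - 1)


T, num_buckets, max_distance = 4, 8, 16

# relative_position[i][j] = j - i
rel = [[j - i for j in range(T)] for i in range(T)]
buckets = [[bucket(d, num_buckets, max_distance) for d in row] for row in rel]

for row in buckets:
    print(row)
```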
Common pitfalls
- jnp.log(0) at relative_position=0 is -inf. jnp.where evaluates both branches eagerly, so the relative_position_if_large computation produces a non-finite intermediate value there, but is_small is True at those positions so jnp.where picks the safe branch. It works in practice; just don't be alarmed by any warnings.
- int vs int32: keep distances as int32 when bucketing; cast the final result to float32 only at the end.
- max_exact = num_buckets // 4, not num_buckets // 2. It's a quarter of the total, half of one direction.
- Off-by-one on the clamp: clamp to nb_half - 1, not nb_half, so the bucket index within each direction stays in [0, nb_half).
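To see why the clamp matters, the following sketch computes the log-region bucket with and without the final minimum. The function log_bucket and its clamp flag are hypothetical names introduced for illustration. It assumes dist >= max_exact, since smaller distances take the exact branch instead.

```python
import math


def log_bucket(dist, nb_half, max_distance, clamp=True):
    """Log-region bucket within one direction; `clamp` toggles the final minimum."""
    max_exact = nb_half // 2
    scaled = math.log(dist / max_exact) / math.log(max_distance / max_exact)
    b = max_exact + int(scaled * (nb_half - max_exact))
    return min(b, nb_half - 1) if clamp else b


nb_half, max_distance = 16, 128   # corresponds to num_buckets = 32

for dist in (100, 128, 1000):
    clamped = log_bucket(dist, nb_half, max_distance)
    unclamped = log_bucket(dist, nb_half, max_distance, clamp=False)
    print(dist, clamped, unclamped)
```

Without the clamp, a negative offset of 1000 would get index 21, which is the same index the positive half uses for offset +5 (16 + 5), so two unrelated offsets would share one learned bias.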
Problem
Implement t5_relative_pos(seed, T, num_buckets, max_distance):
- Build the (T, T) integer matrix relative_position[i, j] = j - i.
- Apply the T5 bucketing function above to each entry.
- Cast to float32 and return flattened.
seed is unused; kept for signature consistency.
Inputs:
- seed: int (unused).
- T: int, sequence length.
- num_buckets: int N, even (typically 32).
- max_distance: int, typically 128.
Output: 1-D float array of length T · T.