Block-Diagonal Attention Mask
Why this matters
Block-diagonal (a.k.a. chunked) attention partitions the sequence into
contiguous blocks of size B, and lets each position attend only
within its own block. After the softmax, the attention-weight matrix is
block-diagonal: within each B×B block every row sums to 1, and every
entry outside the blocks is zero.
Use cases:
- Sparse-attention papers (BigBird, Sparse Transformer): local block attention is one of the building blocks. Combined with other patterns (global and random tokens in BigBird), it keeps cost sub-quadratic (linear in BigBird's case), while stacking layers still gives each token a receptive field over the whole sequence.
- Document boundaries: when packing several short docs into one training sequence, the first token of doc 2 should NOT attend to doc 1's tokens. If every packed doc has the same length, setting block_size to that length isolates the docs. Variable-length docs need a per-token document ID in place of i // B.
- Long-context model serving: process a long sequence in independent chunks at low cost, then a separate “global” layer mixes information across chunks.
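For the document-boundary case, the block mask with B equal to the doc length is exactly the doc-isolation mask when all docs have the same length. The sketch below shows that case next to a common generalization for variable-length docs, which compares per-token document IDs instead of i // B. Both helper names and the segment_ids input are hypothetical; they are not part of this track's problem.

```python
import jax.numpy as jnp

def doc_mask_equal_length(T, doc_len):
    # Equal-length docs: the plain block mask with B = doc_len isolates each doc.
    i = jnp.arange(T)
    return (i[:, None] // doc_len) == (i[None, :] // doc_len)

def doc_mask_from_segment_ids(segment_ids):
    # Variable-length docs: segment_ids[t] is the (hypothetical) doc index of token t.
    # Two tokens may attend to each other iff they belong to the same doc.
    return segment_ids[:, None] == segment_ids[None, :]

# Three docs of lengths 2, 3 and 1 packed into T = 6 tokens.
seg = jnp.array([0, 0, 1, 1, 1, 2])
m = doc_mask_from_segment_ids(seg)  # (6, 6) boolean; doc blocks of size 2, 3, 1
```

With segment_ids = arange(T) // B the two helpers produce the same mask, so the block mask is the special case where every doc has length B.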
Building the mask
Position i is in block i // B. Position i attends to position
j iff they’re in the same block.
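import jax.numpy as jnp

T, B = 6, 2  # sequence length and block size, matching the example below
i = jnp.arange(T)
same_block = (i[:, None] // B) == (i[None, :] // B)  # (T, T) boolean
mask = same_block.astype(jnp.float32)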
For T=6, B=2:
mask =
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1]]
Pure block diagonal — three 2×2 blocks of ones, zeros elsewhere.
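To check the claim that the attention weights come out block-diagonal, the sketch below applies this mask the way typical softmax attention does: masked scores are replaced by a very large negative number before the softmax. The random scores are an assumption for illustration, standing in for q·kᵀ/√d.

```python
import jax
import jax.numpy as jnp

T, B = 6, 2
i = jnp.arange(T)
mask = (i[:, None] // B) == (i[None, :] // B)  # (T, T) boolean

# Random stand-in for the pre-softmax scores q @ k.T / sqrt(d).
scores = jax.random.normal(jax.random.PRNGKey(0), (T, T))

# Masked entries get the most negative float32, so exp() sends them to 0.
masked_scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
weights = jax.nn.softmax(masked_scores, axis=-1)

row_sums = weights.sum(axis=-1)                  # each row should sum to ~1
outside_max = jnp.where(mask, 0.0, weights).max()  # largest weight outside the blocks
```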
Cost reduction
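With T tokens in T/B blocks:
- Naive attention: T² score-matrix entries, all evaluated.
- Block attention: T * B non-zero entries, T/B times fewer.
For T=4096, B=128: 32× fewer entries. The saving is only realized when the
implementation computes the T/B blocks separately (for example by reshaping to
(T/B, B, D)); applying a dense (T, T) mask, as in the problem below, still
evaluates all T² scores and then discards most of them. Either way, block
attention beats dense attention only when losing information across block
boundaries is acceptable for the task.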
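The sketch below makes the saving concrete. It is a minimal single-head version that assumes T % B == 0 and leaves out the Q/K/V and output projections; the function names are hypothetical. The blocked version reshapes the sequence into T/B blocks, so only (T/B)·B·B = T·B scores are ever formed, and it is checked against the dense masked computation.

```python
import jax
import jax.numpy as jnp

def dense_masked_attention(q, k, v, B):
    # Reference: build all T*T scores, then mask. Still O(T^2) work.
    T, d = q.shape
    i = jnp.arange(T)
    mask = (i[:, None] // B) == (i[None, :] // B)
    scores = q @ k.T / jnp.sqrt(d)
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ v

def blocked_attention(q, k, v, B):
    # Only the T/B diagonal blocks are computed: T * B scores in total.
    # Assumes T % B == 0 so the sequence reshapes cleanly into blocks.
    T, d = q.shape
    n = T // B
    qb, kb, vb = (a.reshape(n, B, d) for a in (q, k, v))
    scores = jnp.einsum("nqd,nkd->nqk", qb, kb) / jnp.sqrt(d)  # (n, B, B)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("nqk,nkd->nqd", weights, vb)  # (n, B, d)
    return out.reshape(T, d)

T, d, B = 8, 4, 2
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (T, d))
k = jax.random.normal(kk, (T, d))
v = jax.random.normal(kv, (T, d))
```

The leading block axis n in both einsums is what keeps each block's softmax independent of every other block.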
Edge cases
- B=1: each position is its own block, so every position attends ONLY to itself. The mask is the identity matrix, and each output row is just that position's own value vector V_i = projection(x_i) (followed by the output projection); there is no mixing across positions.
- B >= T: every position is in block 0, which is equivalent to no mask.
- T % B != 0: the final block is smaller; integer division still produces a correct mask, so the mask-based approach needs no padding (a reshape-based blocked implementation would).
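The three edge cases can be checked directly with the same construction, wrapped here in a hypothetical block_mask helper:

```python
import jax.numpy as jnp

def block_mask(T, B):
    # Same-block test from "Building the mask", returned as a (T, T) boolean.
    i = jnp.arange(T)
    return (i[:, None] // B) == (i[None, :] // B)

identity_case = block_mask(4, 1)  # B = 1: identity matrix
no_mask_case = block_mask(4, 8)   # B >= T: every entry True
ragged_case = block_mask(5, 2)    # T % B != 0: blocks of size 2, 2, 1
```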
Compared to other masks in this track
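- Causal (tril): row i sees [0, i]. Information flows only forward, from earlier positions to later ones.
- Sliding window: row i sees [max(0, i - W + 1), i]. Local attention with causality.
- Block-diagonal (this one): row i sees [block_start, block_end], where block_start = (i // B) * B and block_end = min(block_start + B, T) - 1, regardless of i's position in the block. NO information leaks across block boundaries, in either direction. Within a block, though, a position can see later positions, so this mask is not causal.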
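The three masks can be written side by side from the same row/column index grids. The values of T, W and B below are illustrative choices, with W as the sliding-window width:

```python
import jax.numpy as jnp

T, W, B = 6, 3, 2
idx = jnp.arange(T)
row, col = idx[:, None], idx[None, :]  # query index (rows), key index (columns)

causal = col <= row                      # row i sees [0, i]
sliding = (col <= row) & (col > row - W)  # row i sees [max(0, i - W + 1), i]
block = (row // B) == (col // B)          # row i sees its whole block
```

Row 2 of block sees position 3, which is in the same block but in the future; causal and sliding both forbid that entry.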
Common pitfalls
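- (i // B) == (j // B): both sides need integer division. In JAX, // works on int arrays directly; no need to cast.
- Wrong axes for the ==: use i[:, None] vs i[None, :] to get a (T, T) matrix. Writing (i // B) == (i // B) without the new axes compares each position with itself and gives a 1-D, all-True, length-T boolean.
- Bool vs. float mask: Flax's attention applies the mask with jnp.where, so a boolean mask and a float32 mask of 1.0/0.0 keep the same entries, and casting to float32 as the problem asks is harmless. Do not pass an additive mask (0 to keep, -inf to drop): any nonzero value counts as "keep", so that mask would be inverted.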
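The last two pitfalls can be reproduced in a few lines. The jnp.where call below mirrors how Flax's dot-product attention applies a mask, with the most negative float32 as the fill value; the arange scores are an arbitrary stand-in.

```python
import jax.numpy as jnp

T, B = 6, 2
i = jnp.arange(T)

wrong = (i // B) == (i // B)                   # shape (T,): each position vs itself, all True
right = (i[:, None] // B) == (i[None, :] // B)  # shape (T, T): the block mask

scores = jnp.arange(T * T, dtype=jnp.float32).reshape(T, T)
neg = jnp.finfo(jnp.float32).min

# jnp.where treats any nonzero float as True, so bool and 0/1 float masks agree.
same = jnp.array_equal(jnp.where(right, scores, neg),
                       jnp.where(right.astype(jnp.float32), scores, neg))
```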
Problem
Implement mha_masked_block(seed, x, num_heads, qkv_features, block_size):
- T = x.shape[0]. Build same_block = (i[:, None] // B) == (i[None, :] // B), with i = jnp.arange(T) and B = block_size.
- Cast to float32.
- Apply nn.MultiHeadDotProductAttention with mask= set.
- Return out.reshape(-1).
Inputs:
- seed: int.
- x: 2-D (T, D_in).
- num_heads, qkv_features: ints.
- block_size: int B.
Output: 1-D, the flattened block-attention output.