Block-Diagonal Attention Mask

Why this matters

Block-diagonal (a.k.a. chunked) attention partitions the sequence into contiguous blocks of size B and lets each position attend only within its own block. The attention-weight matrix becomes block-diagonal: within each B×B block every row sums to 1 (the softmax runs only over that block), and every entry outside the blocks is zero.
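A minimal sketch in plain JAX (single head, toy random q and k, no projections; all sizes are assumptions) showing that masking cross-block scores before the softmax leaves a block-diagonal weight matrix whose rows each sum to 1:

```python
import jax
import jax.numpy as jnp

# Assumed toy sizes: T positions, block size B, head dimension d.
T, B, d = 6, 2, 4
q = jax.random.normal(jax.random.PRNGKey(0), (T, d))
k = jax.random.normal(jax.random.PRNGKey(1), (T, d))

i = jnp.arange(T)
same_block = (i[:, None] // B) == (i[None, :] // B)  # (T, T) bool

scores = q @ k.T / jnp.sqrt(d)                    # dense (T, T) raw scores
scores = jnp.where(same_block, scores, -jnp.inf)  # drop cross-block pairs
weights = jax.nn.softmax(scores, axis=-1)         # softmax along each row

print(jnp.round(weights, 2))
```

Each row has at least one unmasked entry (its own diagonal), so the softmax is always well defined.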

Use cases:

  • Sparse-attention papers (BigBird, Sparse Transformer): block-local attention is one of the building blocks, combined with other patterns (e.g. global tokens and random tokens in BigBird) to keep the cost sub-quadratic while still reaching a full receptive field after several layers.
  • Document boundaries: when packing multiple short docs into one training sequence, you want the first token of doc 2 NOT to attend to doc 1’s tokens. If every packed doc has the same length L, set block_size = L and the docs are isolated from each other.
  • Long-context model serving: process a long sequence in independent chunks at low cost, then a separate “global” layer mixes information across chunks.
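A fixed block_size only isolates documents when all packed docs have the same length. A common generalization, sketched here under that framing (segment_mask and the segment IDs are hypothetical names for illustration), compares per-token document IDs instead of i // B; with segment_ids = i // B it reduces exactly to the block mask.

```python
import jax.numpy as jnp

def segment_mask(segment_ids):
    """mask[i, j] = 1.0 iff tokens i and j belong to the same packed document."""
    seg = jnp.asarray(segment_ids)
    return (seg[:, None] == seg[None, :]).astype(jnp.float32)

# Hypothetical packing: doc 0 has 3 tokens, doc 1 has 2, doc 2 has 1.
segment_ids = jnp.array([0, 0, 0, 1, 1, 2])
mask = segment_mask(segment_ids)
```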

Building the mask

Position i is in block i // B. Position i attends to position j iff they’re in the same block.

import jax.numpy as jnp

T, B = 6, 2  # example sizes used below
i = jnp.arange(T)
same_block = (i[:, None] // B) == (i[None, :] // B)  # (T, T): True iff same block
mask = same_block.astype(jnp.float32)

For T=6, B=2:

mask =
[[1, 1, 0, 0, 0, 0],
 [1, 1, 0, 0, 0, 0],
 [0, 0, 1, 1, 0, 0],
 [0, 0, 1, 1, 0, 0],
 [0, 0, 0, 0, 1, 1],
 [0, 0, 0, 0, 1, 1]]

Pure block diagonal — three 2×2 blocks of ones, zeros elsewhere.
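As a cross-check, assuming B divides T, the same matrix is the Kronecker product of a (T/B)×(T/B) identity with a B×B block of ones:

```python
import jax.numpy as jnp

T, B = 6, 2
i = jnp.arange(T)
mask = ((i[:, None] // B) == (i[None, :] // B)).astype(jnp.float32)

# T // B copies of a B x B all-ones block placed along the diagonal.
expected = jnp.kron(jnp.eye(T // B), jnp.ones((B, B)))
assert jnp.array_equal(mask, expected)
```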

Cost reduction

With T tokens in T/B blocks:

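  • Naive attention: T² score-matrix entries, all evaluated.
  • Block attention: (T/B) blocks of B² entries = T·B non-zero entries, a factor of T/B fewer.

For T=4096, B=128 that is 32× fewer score entries. The saving only materializes when the kernel computes just the diagonal blocks; passing the mask to dense attention (as in the problem below) still evaluates all T² scores and discards most of them. Block attention is worth it when losing cross-block information is acceptable for the task.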
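To actually get the T/B saving, the off-diagonal blocks have to be skipped rather than masked. A minimal sketch (single head, no projections, B assumed to divide T; both function names are hypothetical) reshapes the sequence into T/B independent blocks and compares the result with masked dense attention:

```python
import jax
import jax.numpy as jnp

def dense_masked_attention(q, k, v, B):
    """Reference: full (T, T) scores, cross-block pairs masked to -inf."""
    T, d = q.shape
    i = jnp.arange(T)
    same_block = (i[:, None] // B) == (i[None, :] // B)
    scores = jnp.where(same_block, q @ k.T / jnp.sqrt(d), -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v

def blockwise_attention(q, k, v, B):
    """Only the diagonal blocks: (T/B, B, B) scores, i.e. T*B entries instead of T*T.
    Assumes B divides T."""
    T, d = q.shape
    nb = T // B
    qb, kb, vb = (a.reshape(nb, B, d) for a in (q, k, v))
    scores = jnp.einsum("nqd,nkd->nqk", qb, kb) / jnp.sqrt(d)  # (nb, B, B)
    out = jnp.einsum("nqk,nkd->nqd", jax.nn.softmax(scores, axis=-1), vb)
    return out.reshape(T, d)

q, k, v = (jax.random.normal(jax.random.PRNGKey(s), (8, 3)) for s in range(3))
same = jnp.allclose(blockwise_attention(q, k, v, 4),
                    dense_masked_attention(q, k, v, 4), atol=1e-4)
```

Both versions produce the same output; only the blockwise one avoids materializing the (T, T) score matrix.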

Edge cases

  • B=1: each position is its own block, so every position attends ONLY to itself and the mask is the identity matrix. The softmax over a single entry is 1, so each output is just the output projection of that position’s own value vector (no mixing across positions).
  • B >= T: every position is in block 0 — equivalent to no mask.
  • T % B != 0: the final block is smaller; the integer division still produces a sensible mask. No padding required.
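The three edge cases can be checked directly with a small helper (block_mask is a hypothetical name wrapping the mask construction from "Building the mask"):

```python
import jax.numpy as jnp

def block_mask(T, B):
    """float32 (T, T) mask: 1.0 iff positions i and j fall in the same block."""
    i = jnp.arange(T)
    return ((i[:, None] // B) == (i[None, :] // B)).astype(jnp.float32)

# B = 1: identity, each position sees only itself.
assert jnp.array_equal(block_mask(5, 1), jnp.eye(5))
# B >= T: one block covering everything, equivalent to no mask.
assert jnp.array_equal(block_mask(5, 8), jnp.ones((5, 5)))
# T % B != 0: T=5, B=2 gives blocks {0, 1}, {2, 3}, {4}; the last block has size 1.
print(block_mask(5, 2))
```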

Compared to other masks in this track

  • Causal (tril): row i sees [0, i]. Information flows forward.
  • Sliding window: row i sees [i - W + 1, i]. Local attention with causality.
  • Block-diagonal (this one): row i sees [block_start, block_end] regardless of where i sits in its block. Attention is bidirectional inside a block (a position also sees later positions in the same block), and NO information crosses block boundaries, neither past nor future.
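The three patterns differ only in the predicate applied to the (row, column) index grid. A side-by-side sketch (the W and B values are arbitrary assumptions):

```python
import jax.numpy as jnp

T, W, B = 6, 3, 2            # assumed sizes: window W, block B
i = jnp.arange(T)[:, None]   # query index (row)
j = jnp.arange(T)[None, :]   # key index (column)

causal = j <= i                        # row i sees [0, i]
sliding = (j <= i) & (j >= i - W + 1)  # row i sees [max(0, i - W + 1), i]
block = (i // B) == (j // B)           # row i sees its whole block, both directions
```

Here block[0, 1] is True while causal[0, 1] is False: inside a block, position 0 also sees position 1.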

Common pitfalls

  • (i // B) == (j // B) — both sides need integer division. In JAX, // works on int arrays directly; no need to cast.
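  • Wrong axes for the ==: use i[:, None] vs i[None, :] so broadcasting yields a (T, T) matrix. Comparing two 1-D arrays, e.g. (i // B) == (i // B), only gives a length-T boolean vector.
  • Mask dtype: Flax’s attention uses the mask as a keep/drop selector (entries that are False or 0 are masked out), and Flax’s own helpers such as nn.make_attention_mask return float32 masks by default. Casting the boolean same_block to float32 keeps it consistent with those helpers; this problem expects the float32 version.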
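Both pitfalls can be checked in isolation. A sketch with toy scores; the jnp.where step is an assumption meant to mirror how masked attention typically applies a mask (replace masked scores with a very negative number):

```python
import jax.numpy as jnp

T, B = 4, 2
i = jnp.arange(T)

wrong = (i // B) == (i // B)                    # shape (T,): each i compared to itself
right = (i[:, None] // B) == (i[None, :] // B)  # shape (T, T): all (i, j) pairs

# A 0/1 float mask and its boolean source select the same entries in jnp.where,
# because a non-boolean condition is treated as "nonzero means True".
scores = jnp.arange(T * T, dtype=jnp.float32).reshape(T, T)
neg = jnp.finfo(jnp.float32).min
via_bool = jnp.where(right, scores, neg)
via_float = jnp.where(right.astype(jnp.float32), scores, neg)
```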

Problem

Implement mha_masked_block(seed, x, num_heads, qkv_features, block_size):

  1. T = x.shape[0]. Build same_block = (i[:, None] // B) == (i[None, :] // B).
  2. Cast to float32.
  3. Apply nn.MultiHeadDotProductAttention to x as self-attention (queries, keys and values all come from x), passing the mask via mask=.
  4. Return out.reshape(-1).

Inputs:

  • seed: int.
  • x: 2-D (T, D_in).
  • num_heads, qkv_features: ints.
  • block_size: int B.

Output: 1-D, the flattened block-attention output.
