
Sliding-Window Attention (Mistral-style)

Why this matters

Standard attention is O(T²) per layer: every token attends to every other token (or to every earlier one, with a causal mask). At long contexts (32k+), this is painful. Sliding-window attention restricts each position to attend only to itself and the previous W-1 positions, a window of W positions in total. Cost drops to O(T*W), which is linear in T for a fixed window size.
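To make the O(T²) versus O(T*W) gap concrete, here is a small sketch that counts how many (query, key) pairs get a score under full, causal, and windowed attention. The helper scored_pairs and the sizes (W = 256, T up to 4096) are illustrative choices, not from any library.

```python
def scored_pairs(T: int, W: int) -> dict:
    """Count (query, key) pairs that receive an attention score (hypothetical helper)."""
    full = T * T                                   # every query scores every key
    causal = T * (T + 1) // 2                      # query i scores keys 0..i
    window = sum(min(i + 1, W) for i in range(T))  # query i scores at most W keys
    return {"full": full, "causal": causal, "window": window}

for T in (1024, 2048, 4096):
    print(T, scored_pairs(T, W=256))
```

Doubling T roughly quadruples the full and causal counts but only roughly doubles the windowed one.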

Mistral 7B popularized this with a window of W = 4,096. Stacking many layers gives an effective receptive field much larger than W, because each additional layer lets information travel roughly W more tokens back. Longformer (Beltagy et al., 2020) used the same sliding window earlier, combined with full attention on a few designated “global” tokens.

The mask

Position i attends to positions j where i - W < j ≤ i. Equivalently: 0 ≤ i - j < W.

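import jax.numpy as jnp

T, W = 6, 3                  # sequence length, window size (including self)
i = jnp.arange(T)[:, None]   # shape (T, 1)
j = jnp.arange(T)[None, :]   # shape (1, T)
mask = ((i - j) >= 0) & ((i - j) < W)   # causal AND within the window
mask = mask.astype(jnp.float32)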

For T=6, W=3:

mask =
[[1, 0, 0, 0, 0, 0],
 [1, 1, 0, 0, 0, 0],
 [1, 1, 1, 0, 0, 0],
 [0, 1, 1, 1, 0, 0],
 [0, 0, 1, 1, 1, 0],
 [0, 0, 0, 1, 1, 1]]

Causal AND windowed: row i is 1 only on columns [max(0, i-W+1), i]. Position 5 sees only {3, 4, 5} — its own three-token window.
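One way to sanity-check the band is to build it a second way and compare. This sketch wraps the mask expression above in a hypothetical helper band_mask, checks it against "causal minus everything W or more steps back" built with jnp.tril, and confirms the column range for every row.

```python
import jax.numpy as jnp

def band_mask(T: int, W: int) -> jnp.ndarray:
    """Causal sliding-window mask: 1 where 0 <= i - j < W, else 0 (hypothetical helper)."""
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    return (((i - j) >= 0) & ((i - j) < W)).astype(jnp.float32)

T, W = 6, 3
ones = jnp.ones((T, T), dtype=jnp.float32)
# jnp.tril(ones, k=-W) keeps entries with j <= i - W, i.e. i - j >= W,
# so subtracting it from the causal mask leaves exactly 0 <= i - j < W.
alt = jnp.tril(ones) - jnp.tril(ones, k=-W)
assert bool(jnp.array_equal(band_mask(T, W), alt))

# Row i is 1 exactly on columns max(0, i - W + 1) .. i.
m = band_mask(T, W)
for row in range(T):
    cols = [int(c) for c in jnp.nonzero(m[row])[0]]
    assert cols == list(range(max(0, row - W + 1), row + 1))
```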

Note: the diagonal is always included (i = j gives i - j = 0, and 0 < W for any W >= 1), so every position attends at least to itself.

Plug into Flax

Same as causal MHA, just with the band mask instead of tril:

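import jax
import flax.linen as nn
# x: (T, D_in) input; mask: the (T, T) band mask above; H heads, D = qkv_features
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
params = attn.init(jax.random.PRNGKey(0), x, mask=mask)
out = attn.apply(params, x, mask=mask)   # self-attention: keys/values default to x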

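Flax applies the mask to attention weights of shape [batch..., num_heads, T, T] with ordinary broadcasting, so a single (T, T) mask covers every head (and every batch element when x is batched). Positions where the mask is 0 are excluded from the softmax.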

Why W < T matters

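If W >= T, every row i has only i + 1 <= T <= W candidate positions, so the window never cuts anything off and the mask is exactly the causal (lower-triangular) mask. The window only has an effect when W < T. For example, with T = 32,768 and Mistral's W = 4,096, each query scores at most 4,096 of up to 32,768 keys, so only about 1/8 (≈11.7%) of the full T × T score matrix is ever computed.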
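Both claims can be checked directly. The sketch reuses a hypothetical band_mask helper (the mask expression from above), confirms that W >= T reproduces jnp.tril, and computes the kept fraction for the 32,768 / 4,096 example in closed form, since materializing that mask would need about 10⁹ entries. kept_fraction is likewise a hypothetical name.

```python
import jax.numpy as jnp

def band_mask(T, W):
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    return (((i - j) >= 0) & ((i - j) < W)).astype(jnp.float32)

# With W >= T the band never cuts anything off: it equals the causal mask.
T = 6
causal = jnp.tril(jnp.ones((T, T), dtype=jnp.float32))
for W in (T, T + 1, 100):
    assert bool(jnp.array_equal(band_mask(T, W), causal))
assert not bool(jnp.array_equal(band_mask(T, T - 1), causal))

def kept_fraction(T: int, W: int) -> float:
    # First W rows keep 1..W entries; the remaining T - W rows keep W each.
    kept = W * (W + 1) // 2 + (T - W) * W if W < T else T * (T + 1) // 2
    return kept / (T * T)

print(kept_fraction(32_768, 4_096))  # about 0.117, i.e. roughly 1/8
```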

Stacking gives effective long-range context

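A single layer with W = 4 lets each position see only 4 tokens: itself and the 3 before it. With L stacked layers, information can hop backward one window per layer, so the output at position i after L layers can depend on positions [max(0, i - L(W - 1)), i]. Each hop reaches W - 1 tokens further back, not W, because the window includes the position itself. This is the argument Mistral 7B makes for reaching far beyond its window: with 32 layers and W = 4,096, the theoretical attention span at the last layer is about 131K tokens. That is an upper bound on what can flow through the network, not a guarantee that distant tokens are used well.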
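The stacking argument can be checked by composing masks: if reach marks which tokens can influence each position, one more layer extends it by one boolean multiplication with the band mask. This sketch (hypothetical helpers band_mask and receptive_field, illustrative sizes T = 32, W = 4) treats the mask as a directed graph and ignores learned weights; it only measures which tokens can influence a position.

```python
import jax.numpy as jnp

def band_mask(T, W):
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    return ((i - j) >= 0) & ((i - j) < W)

def receptive_field(T: int, W: int, L: int) -> jnp.ndarray:
    """reach[i, j] is True if token j can influence position i after L layers."""
    m = band_mask(T, W).astype(jnp.int32)
    reach = jnp.eye(T, dtype=jnp.int32)  # before any layer, i only "sees" itself
    for _ in range(L):
        # One more layer: i reaches j if i attends to some k that already reaches j.
        reach = ((m @ reach) > 0).astype(jnp.int32)
    return reach.astype(bool)

T, W = 32, 4
i = T - 1
for L in (1, 2, 3):
    reach = receptive_field(T, W, L)
    earliest = int(jnp.argmax(reach[i].astype(jnp.int32)))  # first reachable column
    print(L, earliest, max(0, i - L * (W - 1)))  # measured vs. formula
```

For i = 31 and W = 4 this finds earliest positions 28, 25 and 22 for L = 1, 2, 3, matching i - L(W - 1).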

Common pitfalls

  • Off-by-one in the band: (i - j) <= W - 1 and (i - j) < W are equivalent for integers, and both give “self + previous W-1 tokens”. Using (i - j) <= W admits W+1 positions, one too many.
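  • Including future tokens: drop the (i - j) >= 0 clause and (i - j) < W alone also accepts every j > i (there i - j is negative), so each row sees the previous W-1 tokens plus the entire future. That is not a symmetric window; a symmetric local window needs |i - j| < W. Attending to future tokens is wrong for an autoregressive model but fine for an encoder (Longformer's sliding window is symmetric).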
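  • Wrong mask shape: Flax combines the mask with attention weights of shape [batch..., num_heads, T, T], so the mask must broadcast to that shape. A 2-D (T, T) mask does. A 3-D (B, T, T) mask without a head axis lines B up against num_heads, so it typically fails to broadcast (or silently misapplies the mask if B happens to equal num_heads); add the head axis, e.g. mask[:, None], for per-example masks.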
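Each pitfall can be seen on the T = 6, W = 3 example. The variable names are illustrative; the last part uses jnp.broadcast_shapes to mimic how a mask must broadcast against attention weights of shape (batch, num_heads, T, T), with B = 2 and H = 4 as arbitrary sizes.

```python
import jax.numpy as jnp

T, W = 6, 3
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
d = i - j

correct = (d >= 0) & (d < W)         # self + previous W-1 tokens
same = (d >= 0) & (d <= W - 1)       # equivalent for integers
too_wide = (d >= 0) & (d <= W)       # W+1 positions in a full row
no_causal = d < W                    # causal clause dropped
symmetric = jnp.abs(d) < W           # W-1 tokens on each side plus self

print(jnp.array_equal(correct, same))  # True
print(int(too_wide[5].sum()))          # 4, i.e. W + 1
print(no_causal[0].astype(int))        # row 0 sees every future token
print(symmetric[2].astype(int))        # row 2 sees columns 0..4

# Shape check against attention weights of shape (B, H, T, T).
B, H = 2, 4
weights_shape = (B, H, T, T)
print(jnp.broadcast_shapes(correct.shape, weights_shape))  # (2, 4, 6, 6)
try:
    # A (B, T, T) mask aligns B with the head axis: 2 vs 4 cannot broadcast.
    jnp.broadcast_shapes((B, T, T), weights_shape)
except ValueError as err:
    print("bad mask shape:", err)
```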

Problem

Implement mha_sliding_window(seed, x, num_heads, qkv_features, window):

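  1. T = x.shape[0]. Build the band mask with W = window: mask = ((i - j) >= 0) & ((i - j) < W), cast to float32.
  2. Apply nn.MultiHeadDotProductAttention(num_heads=num_heads, qkv_features=qkv_features), initialized from seed, passing mask=mask to both init and apply.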
  3. Return out.reshape(-1).

Inputs:

  • seed: int.
  • x: 2-D (T, D_in).
  • num_heads, qkv_features: ints.
  • window: int W. Window size including self.

Output: 1-D array of length T * D_in, the windowed attention output flattened (Flax's output projection defaults to the input feature size D_in).

