Sliding-Window Attention (Mistral-style)
Why this matters
Standard attention is O(T²) per layer — every token attends to every
other token. At long contexts (32k+), this is painful. Sliding-window
attention restricts each position to attend ONLY to the previous
W positions (and itself). Cost drops to O(T*W) — linear in T for
fixed window size.
Mistral 7B popularized this with W = 4096. Stacking many layers creates an effective receptive field much larger than W, because each additional layer lets information travel roughly one more window-width back. Longformer (Beltagy et al., 2020) combined the same windowing idea with full attention on a few “global” tokens.
The mask
Position i attends to positions j where i - W < j ≤ i.
Equivalently: 0 ≤ i - j < W.
i = jnp.arange(T)[:, None] # shape (T, 1)
j = jnp.arange(T)[None, :] # shape (1, T)
mask = ((i - j) >= 0) & ((i - j) < W)
mask = mask.astype(jnp.float32)
For T=6, W=3:
mask =
[[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1]]
Causal AND windowed: row i is 1 only on columns [max(0, i-W+1), i].
Position 5 sees only {3, 4, 5} — its own three-token window.
Note: this includes the diagonal (i=j → 0 < W, true), so position
i always attends to itself (assuming W >= 1).
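The mask construction above can be wrapped in a small helper and checked against the T=6, W=3 example. This is a minimal sketch; the function name sliding_window_mask is a hypothetical choice for illustration and is not part of the original text.

```python
import jax.numpy as jnp


def sliding_window_mask(T: int, W: int) -> jnp.ndarray:
    """Causal band mask: entry (i, j) is 1.0 iff 0 <= i - j < W."""
    i = jnp.arange(T)[:, None]  # query positions, shape (T, 1)
    j = jnp.arange(T)[None, :]  # key positions, shape (1, T)
    diff = i - j                # broadcasts to shape (T, T)
    return ((diff >= 0) & (diff < W)).astype(jnp.float32)


mask = sliding_window_mask(6, 3)
# Number of visible keys per query row: grows 1, 2, 3 and then stays at W.
row_counts = mask.sum(axis=1)
```

Each row sum equals min(i + 1, W), which is a quick sanity check that the band is both causal and exactly W wide.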
Plug into Flax
Same as causal MHA, just with the band mask instead of tril:
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
params = attn.init(rng, x, mask=mask)
out = attn.apply(params, x, mask=mask)
Flax broadcasts the (T, T) mask over batch and heads automatically.
Why W < T matters
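If W >= T, the mask becomes purely causal (lower triangular). The
“window” doesn’t kick in until W < T. For Mistral, T can be
32,768 and W is 4,096, so each full-length row keeps only W/T = 1/8
of its columns and the mask blocks roughly 88% of the T×T score
matrix (a purely causal mask already blocks about half).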
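The numbers are easy to check by counting the entries the mask keeps. The helper below is a hypothetical illustration (unmasked_count is not a name from the original); it counts kept entries in closed form and also confirms that W >= T degenerates to a plain lower-triangular mask.

```python
import jax.numpy as jnp


def unmasked_count(T: int, W: int) -> int:
    """Number of (i, j) score entries kept by the causal sliding-window mask."""
    w = min(W, T)
    # The first w rows keep 1, 2, ..., w entries; the remaining T - w rows keep w each.
    return w * (w + 1) // 2 + (T - w) * w


T, W = 32_768, 4_096
kept = unmasked_count(T, W)
kept_fraction = kept / (T * T)                   # roughly 0.117 of the full matrix
causal_fraction = (T * (T + 1) // 2) / (T * T)   # roughly 0.5 for a plain causal mask

# With W >= T the band mask is identical to the lower-triangular causal mask.
i = jnp.arange(5)[:, None]
j = jnp.arange(5)[None, :]
band = ((i - j) >= 0) & ((i - j) < 9)
is_causal = bool(jnp.array_equal(band, jnp.tril(jnp.ones((5, 5), dtype=bool))))
```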
Stacking gives effective long-range
A single layer with W=4 only sees 4 tokens of context (itself plus
the previous 3). But with L layers, the deepest layer’s
representation at position i was influenced by positions
[i - L*(W-1), i], because every layer reaches W-1 positions further
back. This is how Mistral handles 32k+ context with a 4k window.
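The growth of the receptive field can be checked by composing the band mask with itself once per layer. The sketch below treats the mask as a reachability graph and ignores everything else about the model; the function name and the sizes are assumptions made for illustration. Since a window of W includes the token itself, each layer reaches W - 1 positions further back, giving a span of L*(W-1) + 1 tokens after L layers.

```python
import jax.numpy as jnp


def receptive_field_start(T: int, W: int, num_layers: int) -> jnp.ndarray:
    """Earliest position that can influence each position after num_layers layers."""
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    band = (((i - j) >= 0) & ((i - j) < W)).astype(jnp.int32)
    reach = jnp.eye(T, dtype=jnp.int32)  # before any layer, a position holds only itself
    for _ in range(num_layers):
        # One more layer: position i now sees whatever its window positions had seen.
        reach = ((band @ reach) > 0).astype(jnp.int32)
    return jnp.argmax(reach, axis=1)  # index of the first reachable position per row


starts = receptive_field_start(T=32, W=4, num_layers=3)
span_at_last = 31 - int(starts[31]) + 1  # 3 * (4 - 1) + 1 tokens
```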
Common pitfalls
- Off-by-one in the band: (i - j) <= W - 1 is correct for “self + previous W-1 tokens”. Using < W gives the same result. Using <= W gives W+1 positions, which is one too many.
- Including future tokens: omit the (i - j) >= 0 clause and (i - j) < W holds for every j > i, so each position also attends to all future tokens and causality is lost. A symmetric local window needs abs(i - j) < W instead. For an autoregressive model any future access is wrong; for an encoder a symmetric window is fine (Longformer’s encoder uses one).
- Wrong mask shape: 3-D doesn’t work; 2-D (T, T) does.
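The band-related pitfalls are easiest to see by counting how many keys one row can attend to under each variant. This is a small sketch with illustrative sizes (T=8, W=3) and hypothetical variable names; row 4 is used because it is far enough from both edges to show the full window.

```python
import jax.numpy as jnp

T, W = 8, 3  # illustrative sizes
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
d = i - j

correct = (d >= 0) & (d < W)     # self + previous W-1 tokens
too_wide = (d >= 0) & (d <= W)   # off by one: W+1 tokens
no_lower_bound = d < W           # causality dropped: every future token is visible
symmetric = jnp.abs(d) < W       # encoder-style local window, W-1 tokens on each side

row = 4
counts = {
    "correct": int(correct[row].sum()),
    "too_wide": int(too_wide[row].sum()),
    "no_lower_bound": int(no_lower_bound[row].sum()),
    "symmetric": int(symmetric[row].sum()),
}
```

For row 4 the counts come out as 3, 4, 6, and 5 respectively: only the first matches the intended window size W.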
Problem
Implement mha_sliding_window(seed, x, num_heads, qkv_features, window):
- T = x.shape[0]. Build the band mask: mask = ((i - j) >= 0) & ((i - j) < W), cast to float32.
- Apply MHA with mask=.
- Return out.reshape(-1).
Inputs:
- seed: int.
- x: 2-D (T, D_in).
- num_heads, qkv_features: ints.
- window: int W. Window size including self.
Output: 1-D, the windowed attention output flattened.