NNX Sliding-Window Attention
Why this matters
Standard self-attention has O(T^2) cost: it quadruples every time
you double the sequence length. For a 32k context, that’s about a
billion score entries per head. Sliding-window attention bounds the
receptive field per token to a fixed window of W tokens, giving
O(T * W) cost. With W fixed at 4096, cost grows only linearly in T,
so scaling to T = 1M is no longer quadratic.
Mistral 7B uses sliding-window attention with W = 4096. Because
layers are stacked, information can travel up to W tokens further
back at each layer, so each token’s effective receptive field grows
to roughly W * L (where L is the layer count). That is far enough to
model long contexts while keeping per-layer cost linear in T.
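The arithmetic behind these claims can be checked in a few lines. This is a back-of-the-envelope sketch, not a benchmark: it counts score entries per head only, and `n_layers = 32` is an assumed layer count used purely to illustrate the W * L bound.

```python
# Score-entry counts per head for one attention layer.
# n_layers is an assumption, used only for the W * L receptive-field bound.
T, W, n_layers = 32_768, 4_096, 32

full = T * T                                    # dense T x T score matrix
causal = T * (T + 1) // 2                       # lower triangle only
banded = sum(min(i + 1, W) for i in range(T))   # causal band of width W

print(f"full:   {full:,}")
print(f"causal: {causal:,}")
print(f"banded: {banded:,}")
print(f"full / banded   = {full / banded:.2f}")
print(f"causal / banded = {causal / banded:.2f}")
print(f"receptive-field upper bound W * L = {W * n_layers:,} tokens")
```

The first W rows of the band are only partially filled, so the exact banded count is slightly below T * W.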
The band mask
Position i attends only to keys in [i - W + 1, i] — itself and
the W - 1 most recent past tokens. Equivalently, query i
attends to key j iff i - j >= 0 AND i - j < W.
Build the mask by computing the index difference matrix:
i = jnp.arange(T)[:, None] # column vector of query indices
j = jnp.arange(T)[None, :] # row vector of key indices
diff = i - j # (T, T) — row i, col j has i - j
in_band = (diff >= 0) & (diff < window) # boolean (T, T)
Reading diff[i, j] = i - j:
- i - j > W - 1 (too far in the past): NOT in band, mask out.
- 0 <= i - j <= W - 1: in band, keep.
- i - j < 0 (future): NOT in band, mask out (this is the causal part).
So sliding-window IS causal — the band is one-sided, looking
only backward. The single condition (diff >= 0) & (diff < W)
handles both causal and locality at once.
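As a quick check of the construction, here is a minimal sketch that wraps the mask in a helper and counts how many keys each query can see. The name `sliding_window_mask` is hypothetical, chosen for illustration only.

```python
import jax.numpy as jnp

def sliding_window_mask(T: int, window: int) -> jnp.ndarray:
    """Boolean (T, T) mask: True where query i may attend to key j."""
    i = jnp.arange(T)[:, None]   # query indices, shape (T, 1)
    j = jnp.arange(T)[None, :]   # key indices, shape (1, T)
    diff = i - j                 # broadcasts to (T, T)
    return (diff >= 0) & (diff < window)

mask = sliding_window_mask(T=6, window=3)

# Query i sees min(i + 1, window) keys: the band fills up, then slides.
keys_per_query = mask.sum(axis=-1)
print(keys_per_query)            # rows see 1, 2, 3, 3, 3, 3 keys

# No entry above the diagonal is ever True, so the mask is causal.
future = jnp.arange(6)[None, :] > jnp.arange(6)[:, None]
print(bool((mask & future).any()))
```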
Apply BEFORE softmax
scores = jnp.where(in_band, scores, -1e9)
weights = jax.nn.softmax(scores, axis=-1)
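Same -1e9 trick as the causal mask. Use a large finite negative
number rather than -jnp.inf: if a row were ever fully masked, -inf
would make softmax evaluate inf - inf and return NaN, which then
propagates into the gradients. (In this band every row keeps at
least its diagonal entry, but the finite fill is the safer habit.)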
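To see why the order matters, the sketch below applies the mask before softmax on random toy scores, then compares with the mistaken order (softmax first, zero out afterwards). The scores are random placeholders, not outputs of a real model.

```python
import jax
import jax.numpy as jnp

T, window = 5, 3
scores = jax.random.normal(jax.random.PRNGKey(0), (T, T))  # toy scores

i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
in_band = (i - j >= 0) & (i - j < window)

# Correct order: mask first, then normalize.
masked = jnp.where(in_band, scores, -1e9)
weights = jax.nn.softmax(masked, axis=-1)

# Out-of-band weights underflow to exactly 0; each row still sums to 1.
leak = jnp.where(in_band, 0.0, weights).max()
print(float(leak), weights.sum(axis=-1))

# Wrong order: normalizing over all T keys and zeroing afterwards
# leaves every row summing to less than 1.
wrong = jnp.where(in_band, jax.nn.softmax(scores, axis=-1), 0.0)
print(wrong.sum(axis=-1))
```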
Worked sketch
class SlidingWindowMHA(nnx.Module):
    # __init__: same four nnx.Linear projections as plain MHA.

    def __call__(self, x, window):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim

        # Project, then split into heads: (T, H * Dh) -> (H, T, Dh).
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)

        # Scaled dot-product scores, shape (H, T, T).
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)

        # Causal band mask, shape (T, T).
        i = jnp.arange(T)[:, None]
        j = jnp.arange(T)[None, :]
        diff = i - j
        in_band = (diff >= 0) & (diff < window)

        scores = jnp.where(in_band, scores, -1e9)
        weights = jax.nn.softmax(scores, axis=-1)

        # Weighted values, then merge heads: (H, T, Dh) -> (T, H * Dh).
        per_head = jnp.matmul(weights, v)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)
The mask in_band has shape (T, T) — broadcasts over the head
axis automatically, just like the causal mask.
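One way to gain confidence in the masked formulation is to compare it against a loop that slices each query's window explicitly. The sketch below is framework-free: it takes already-projected q, k, v of shape (H, T, Dh) and leaves out the nnx projections. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def sliding_window_attention(q, k, v, window):
    """Masked formulation. q, k, v: (H, T, Dh). Returns (H, T, Dh)."""
    H, T, Dh = q.shape
    scores = jnp.einsum("htd,hsd->hts", q, k) / jnp.sqrt(Dh)   # (H, T, T)
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    in_band = (i - j >= 0) & (i - j < int(window))              # (T, T)
    scores = jnp.where(in_band, scores, -1e9)   # (T, T) broadcasts over H
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hts,hsd->htd", weights, v)

def naive_reference(q, k, v, window):
    """Per-query loop that slices the window out directly (no mask)."""
    H, T, Dh = q.shape
    rows = []
    for t in range(T):
        lo = max(0, t - int(window) + 1)        # first visible key
        s = jnp.einsum("hd,hsd->hs", q[:, t], k[:, lo:t + 1]) / jnp.sqrt(Dh)
        w = jax.nn.softmax(s, axis=-1)
        rows.append(jnp.einsum("hs,hsd->hd", w, v[:, lo:t + 1]))
    return jnp.stack(rows, axis=1)              # (H, T, Dh)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
H, T, Dh, window = 2, 7, 4, 3
q = jax.random.normal(kq, (H, T, Dh))
k = jax.random.normal(kk, (H, T, Dh))
v = jax.random.normal(kv, (H, T, Dh))

out = sliding_window_attention(q, k, v, window)
ref = naive_reference(q, k, v, window)
print(out.shape, float(jnp.abs(out - ref).max()))
```

The two agree up to floating-point rounding, which confirms that filling with -1e9 behaves like removing the out-of-band keys.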
Worked example
With T = 5, W = 3:
in_band:
[[1, 0, 0, 0, 0] # pos 0 sees only pos 0
[1, 1, 0, 0, 0] # pos 1 sees pos {0, 1}
[1, 1, 1, 0, 0] # pos 2 sees pos {0, 1, 2}
[0, 1, 1, 1, 0] # pos 3 sees pos {1, 2, 3}
[0, 0, 1, 1, 1]] # pos 4 sees pos {2, 3, 4}
Past W = 3 tokens, the receptive field slides forward —
position 3 no longer sees position 0.
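The same table can be reproduced with triangular helpers instead of the index-difference matrix. This alternative construction is shown only as a cross-check; it is not taken from the original sketch.

```python
import jax.numpy as jnp

T, W = 5, 3
ones = jnp.ones((T, T), dtype=jnp.int32)

# tril keeps i - j >= 0 (the causal half);
# triu with k = -(W - 1) keeps i - j <= W - 1 (the locality half).
band = jnp.tril(ones) * jnp.triu(ones, k=-(W - 1))
print(band)
```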
What this achieves
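Each token needs at most W (not T) dot products per head in the
attention layer. Cost goes from O(T^2) to O(T * W). For Mistral 7B
with W = 4096 and T = 32k, that’s roughly an 8x reduction relative to
full T x T attention (and up to 8x less KV-cache read per decoding
step once the context reaches 32k). Note that a masked implementation
like the sketch above still materializes the full (T, T) score
matrix; the savings only appear when out-of-band entries are skipped
rather than computed and masked.
The KV cache also gets smaller in principle: only the last W
tokens matter at each step. Mistral 7B implements this as a rolling
(circular) buffer cache.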
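A circular buffer for the cache can be sketched as follows. This is a minimal single-head illustration under simplifying assumptions (no batching, no positional encoding, a Python loop over steps), not Mistral's actual implementation. Position t is written to slot t % window, which overwrites the entry that just left the window.

```python
import jax
import jax.numpy as jnp

def decode_with_rolling_cache(q, k, v, window):
    """Single-head decode. q, k, v: (T, Dh). Stores only `window` K/V rows."""
    T, Dh = q.shape
    k_cache = jnp.zeros((window, Dh))
    v_cache = jnp.zeros((window, Dh))
    outputs = []
    for t in range(T):
        slot = t % window                        # overwrite the oldest entry
        k_cache = k_cache.at[slot].set(k[t])
        v_cache = v_cache.at[slot].set(v[t])
        filled = jnp.arange(window) < (t + 1)    # unwritten slots stay masked
        scores = k_cache @ q[t] / jnp.sqrt(Dh)   # (window,)
        scores = jnp.where(filled, scores, -1e9)
        outputs.append(jax.nn.softmax(scores) @ v_cache)
    return jnp.stack(outputs)                    # (T, Dh)

def masked_reference(q, k, v, window):
    """Full (T, T) scores with the band mask, for comparison."""
    T, Dh = q.shape
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    in_band = (i - j >= 0) & (i - j < window)
    scores = jnp.where(in_band, q @ k.T / jnp.sqrt(Dh), -1e9)
    return jax.nn.softmax(scores, axis=-1) @ v

kq, kk, kv = jax.random.split(jax.random.PRNGKey(1), 3)
T, Dh, window = 9, 4, 3
q = jax.random.normal(kq, (T, Dh))
k = jax.random.normal(kk, (T, Dh))
v = jax.random.normal(kv, (T, Dh))

cached = decode_with_rolling_cache(q, k, v, window)
full = masked_reference(q, k, v, window)
print(float(jnp.abs(cached - full).max()))
```

The slot order inside the buffer does not matter here because softmax-weighted sums are invariant to the order of the keys. Memory stays at `window` rows however long the sequence gets.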
Common pitfalls
- Sign of i - j. If you write diff = j - i you get the transpose, looking forward instead of back. Check by printing in_band for small T.
- Off-by-one. i - j < W (strict) gives a window of size W including the current token. i - j <= W gives W + 1.
- Dropping the causal half. abs(diff) < W would build a symmetric band, letting tokens peek forward. Sliding-window attention is causal AND windowed.
- window arriving as float. Cast to int.
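The pitfalls above are easy to reproduce on a small case. The sketch below builds the correct mask next to each mistaken variant so the differences can be printed and compared. The variable names are illustrative only.

```python
import jax.numpy as jnp

T, window = 5, 3.0           # window arrives as a float, as in the problem
W = int(window)              # cast once, up front

i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
diff = i - j

correct = (diff >= 0) & (diff < W)
flipped = (j - i >= 0) & (j - i < W)    # wrong sign: band looks forward
too_wide = (diff >= 0) & (diff <= W)    # off by one: up to W + 1 keys
symmetric = jnp.abs(diff) < W           # causal half dropped

print(correct.astype(int))
print(flipped.astype(int))              # transpose of the correct mask
print(too_wide.sum(axis=-1))            # some rows see W + 1 keys
print(symmetric.astype(int))            # entries above the diagonal are set
```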
Problem
Write mha_sliding_window(seed, x, num_heads, d_model, window):
- Define SlidingWindowMHA(nnx.Module) with the standard four projections.
- __call__(x, window): compute scaled scores, build in_band from (diff >= 0) & (diff < window), fill the rest with -1e9, softmax, value matmul, out_proj.
- Cast every dimension arg to int. Return flattened.
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- num_heads, d_model, window: ints (passed as floats).
Output: 1-D flattened.