
NNX Sliding-Window Attention

Why this matters

Standard self-attention has O(T^2) cost: doubling the sequence length quadruples the work. For a 32k context, that is about a billion score entries per head (32,768^2 ≈ 1.07e9). Sliding-window attention caps each token's receptive field at a fixed window of W tokens, giving O(T * W) cost. With W = 4096, scaling to T = 1M grows the cost linearly in T instead of quadratically.
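To make the arithmetic concrete, here is a minimal sketch that counts score entries per head per layer. dense_score_entries and window_score_entries are hypothetical helper names; the windowed count is the exact number of (i, j) pairs with 0 <= i - j < W.

```python
def dense_score_entries(T):
    # Full (non-causal) attention: every query scores every key.
    return T * T


def window_score_entries(T, W):
    # Exact count of (i, j) pairs with 0 <= i - j < W.
    # The first min(T, W) rows form a triangle (row i keeps i + 1 keys);
    # every later row keeps exactly W keys.
    if T <= W:
        return T * (T + 1) // 2
    return W * (W + 1) // 2 + (T - W) * W


print(dense_score_entries(32_768))          # 1073741824
print(window_score_entries(32_768, 4_096))  # 125831168
```

The windowed count grows by exactly W per extra token once T > W, which is the linear scaling described above.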

Mistral 7B uses sliding-window attention with W = 4096. Because layers are stacked, information can still travel further than W: after L layers, a token can be influenced by roughly W * L earlier positions (exactly L * (W - 1) + 1, counting the token itself). That is far enough to model long contexts while keeping per-layer cost linear in T.
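One way to check the stacked receptive field is to compose the band mask with itself layer by layer, as a boolean matrix product. This is a toy sketch under the assumption that only attention moves information between positions; band_mask and reachable_positions are hypothetical helper names.

```python
import jax.numpy as jnp


def band_mask(T, window):
    # True where query i may attend key j: 0 <= i - j < window.
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    diff = i - j
    return (diff >= 0) & (diff < window)


def reachable_positions(T, window, num_layers):
    # reach[i, j] == 1 if position j can influence position i after num_layers layers.
    step = band_mask(T, window).astype(jnp.int32)
    reach = jnp.eye(T, dtype=jnp.int32)
    for _ in range(num_layers):
        # Compose one more attention layer, then clip counts back to 0/1.
        reach = jnp.minimum(step @ reach, 1)
    # Number of positions that can influence the last token.
    return int(reach[-1].sum())


print(reachable_positions(T=64, window=4, num_layers=5))  # 16 == 5 * (4 - 1) + 1
print(32 * (4096 - 1) + 1)                                # 131041 for 32 layers, W = 4096
```

With 32 layers (Mistral 7B's depth) and W = 4096 the formula gives 131,041 positions, which is the "roughly W * L" figure above.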

The band mask

Position i attends only to keys in [i - W + 1, i] — itself and the W - 1 most recent past tokens. Equivalently, query i attends to key j iff i - j >= 0 AND i - j < W.

Build the mask by computing the index difference matrix:

import jax.numpy as jnp

T, window = 5, 3                          # example sizes
i = jnp.arange(T)[:, None]               # column vector of query indices
j = jnp.arange(T)[None, :]               # row vector of key indices
diff = i - j                              # (T, T): row i, col j holds i - j
in_band = (diff >= 0) & (diff < window)  # boolean (T, T)

Each entry diff[i, j] = i - j falls into one of three cases:

  • i - j > W - 1 (too far in the past): NOT in band, mask out.
  • 0 <= i - j <= W - 1: in band, keep.
  • i - j < 0 (future): NOT in band — mask out (this is the causal part).

So sliding-window IS causal — the band is one-sided, looking only backward. The single condition (diff >= 0) & (diff < W) handles both causal and locality at once.
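The claim that one condition covers both the causal and the locality constraint can be checked against a construction from two triangular masks: the causal lower triangle minus everything W or more steps in the past. This is a sketch; band_mask_from_triangles is a hypothetical name, and jnp.tri(T, k=...) is True where j <= i + k.

```python
import jax.numpy as jnp


def band_mask(T, window):
    i = jnp.arange(T)[:, None]
    j = jnp.arange(T)[None, :]
    diff = i - j
    return (diff >= 0) & (diff < window)


def band_mask_from_triangles(T, window):
    # True where j <= i, i.e. diff >= 0 (the causal part).
    causal = jnp.tri(T, k=0).astype(bool)
    # True where j <= i - window, i.e. diff >= window (too far in the past).
    too_old = jnp.tri(T, k=-window).astype(bool)
    return causal & ~too_old


for T, W in [(5, 3), (8, 1), (8, 8), (10, 4)]:
    assert bool(jnp.array_equal(band_mask(T, W), band_mask_from_triangles(T, W)))
```

W = 1 leaves only the diagonal, and W >= T reduces to the plain causal mask.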

Apply BEFORE softmax

scores = jnp.where(in_band, scores, -1e9)
weights = jax.nn.softmax(scores, axis=-1)

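Same -1e9 trick as the causal mask. Use a large finite negative number rather than -jnp.inf: if a row ever ends up fully masked (for example when the band is combined with a padding mask), softmax over all -inf computes -inf - (-inf) = NaN, and that NaN propagates through the forward pass and the gradients. With -1e9 the masked entries still become exactly 0 after exp.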
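A small demonstration of the failure mode, assuming a row in which every key is masked (which the band mask alone never produces, since the diagonal is always kept, but a padding mask can):

```python
import jax
import jax.numpy as jnp

# One query row where every key is masked out, e.g. a padding row.
row_mask = jnp.array([False, False, False])
scores = jnp.array([0.3, -1.2, 2.0])

with_inf = jax.nn.softmax(jnp.where(row_mask, scores, -jnp.inf))
with_big_neg = jax.nn.softmax(jnp.where(row_mask, scores, -1e9))

print(with_inf)      # all NaN: the max-subtraction computes -inf - (-inf)
print(with_big_neg)  # finite: uniform weights of 1/3 each
```

The finite fill still gives meaningless (uniform) weights on a fully masked row, but it keeps the numbers finite so a later mask or loss term can discard them.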

Worked sketch

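import jax
import jax.numpy as jnp
from flax import nnx

class SlidingWindowMHA(nnx.Module):
    def __init__(self, d_model, num_heads, *, rngs: nnx.Rngs):
        # Same four nnx.Linear projections as plain MHA (this signature is one possible choice).
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x, window):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # (T, d_model) -> (T, H, Dh) -> (H, T, Dh)
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        # (H, T, Dh) @ (H, Dh, T) -> (H, T, T) scaled scores
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)
        i = jnp.arange(T)[:, None]
        j = jnp.arange(T)[None, :]
        diff = i - j
        in_band = (diff >= 0) & (diff < window)   # (T, T)
        scores = jnp.where(in_band, scores, -1e9)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)          # (H, T, Dh)
        # (H, T, Dh) -> (T, H, Dh) -> (T, H * Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)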

The mask in_band has shape (T, T) and broadcasts against the (H, T, T) scores, so every head gets the same band, just like the causal mask.
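To see the broadcast and the locality guarantee without the projections, here is a stripped-down functional version. sliding_window_attention is a hypothetical helper that assumes q, k, v are already projected and split into heads as (H, T, Dh).

```python
import jax
import jax.numpy as jnp


def sliding_window_attention(q, k, v, window):
    # q, k, v: (H, T, Dh), already projected and split into heads.
    H, T, Dh = q.shape
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(Dh)   # (H, T, T)
    diff = jnp.arange(T)[:, None] - jnp.arange(T)[None, :]
    in_band = (diff >= 0) & (diff < window)             # (T, T)
    # The (T, T) mask broadcasts over the leading head axis.
    scores = jnp.where(in_band, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v                                  # (H, T, Dh)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
H, T, Dh, W = 2, 6, 4, 3
q = jax.random.normal(kq, (H, T, Dh))
k = jax.random.normal(kk, (H, T, Dh))
v = jax.random.normal(kv, (H, T, Dh))

out = sliding_window_attention(q, k, v, W)
# Perturb the key and value at position 0 in every head.
out2 = sliding_window_attention(q, k.at[:, 0].add(5.0), v.at[:, 0].add(5.0), W)
changed = jnp.any(jnp.abs(out - out2) > 1e-6, axis=(0, 2))  # (T,) per position
print(changed)
```

Only positions 0 through W - 1 can see position 0, so only those outputs change; positions W onward are unaffected because the masked entries receive exactly zero weight.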

Worked example

With T = 5, W = 3:

in_band:
[[1, 0, 0, 0, 0]    # pos 0 sees only pos 0
 [1, 1, 0, 0, 0]    # pos 1 sees pos {0, 1}
 [1, 1, 1, 0, 0]    # pos 2 sees pos {0, 1, 2}
 [0, 1, 1, 1, 0]    # pos 3 sees pos {1, 2, 3}
 [0, 0, 1, 1, 1]]   # pos 4 sees pos {2, 3, 4}

Once i >= W, the window slides forward: position 3 no longer sees position 0, and position 4 no longer sees positions 0 or 1.
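The matrix above can be reproduced directly, which is also the quickest way to catch sign and off-by-one mistakes. A minimal sketch:

```python
import jax.numpy as jnp

T, W = 5, 3
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]
diff = i - j
in_band = (diff >= 0) & (diff < W)

print(in_band.astype(jnp.int32))
# Which keys each query position sees.
sees = {pos: [int(idx) for idx in jnp.nonzero(in_band[pos])[0]] for pos in range(T)}
print(sees)
```

The printed 0/1 matrix should match the one listed above row for row.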

What this achieves

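Each query scores at most W keys (not T), so the attention cost drops from O(T^2) to O(T * W). For Mistral's W = 4096 at T = 32k, that is an 8x reduction in attention work (32,768 / 4,096 = 8), and 8x less KV-cache read per decoding step at inference. The saving only materializes with an implementation that skips out-of-band scores: the dense sketch above still builds the full (T, T) score matrix and merely masks it, so it shows the semantics but not the speedup.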
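As an illustration of doing only O(T * W) score computations, here is a single-head sketch that gathers each query's W-key window instead of building the (T, T) matrix. Both function names are hypothetical, and the naive gather is only for clarity; production kernels typically work on blocks rather than materializing a (T, W, D) tensor.

```python
import jax
import jax.numpy as jnp


def dense_sliding_window(q, k, v, window):
    # Reference: all (T, T) scores, then mask. Still O(T^2) work.
    T, D = q.shape
    scores = q @ k.T / jnp.sqrt(D)
    diff = jnp.arange(T)[:, None] - jnp.arange(T)[None, :]
    scores = jnp.where((diff >= 0) & (diff < window), scores, -1e9)
    return jax.nn.softmax(scores, axis=-1) @ v


def banded_sliding_window(q, k, v, window):
    # Only (T, window) scores: each query scores its own window of keys.
    T, D = q.shape
    pad = window - 1
    # Left-pad so every window has exactly `window` rows; key j lands at row j + pad.
    k_pad = jnp.pad(k, ((pad, 0), (0, 0)))
    v_pad = jnp.pad(v, ((pad, 0), (0, 0)))
    # Query i's keys i - window + 1 .. i sit at padded rows i .. i + pad.
    idx = jnp.arange(T)[:, None] + jnp.arange(window)[None, :]  # (T, window)
    k_win = k_pad[idx]                                          # (T, window, D)
    v_win = v_pad[idx]
    scores = jnp.einsum("td,twd->tw", q, k_win) / jnp.sqrt(D)   # (T, window)
    # Padded rows (index < pad) stand for keys before position 0.
    scores = jnp.where(idx >= pad, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("tw,twd->td", weights, v_win)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
T, D = 16, 8
q = jax.random.normal(kq, (T, D))
k = jax.random.normal(kk, (T, D))
v = jax.random.normal(kv, (T, D))
for W in (1, 4, 20):
    assert bool(jnp.allclose(dense_sliding_window(q, k, v, W),
                             banded_sliding_window(q, k, v, W), atol=1e-4))
```

Both functions compute the same attention; only the banded one keeps the score tensor at (T, W).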

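The KV cache can also be capped: only the last W tokens matter at each step. Mistral 7B implements this as a fixed-size rolling (circular) buffer, writing the keys and values for position i into slot i mod W so the oldest entry is overwritten.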
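A single-head sketch of the rolling buffer during token-by-token decoding, checked against the dense masked version. decode_with_rolling_cache and dense_sliding_window are hypothetical names, and q, k, v are assumed to be precomputed per-position projections. The order of entries inside the buffer does not matter because softmax attention is a weighted sum over whatever keys are present.

```python
import jax
import jax.numpy as jnp


def dense_sliding_window(q, k, v, window):
    # Reference: full (T, T) scores with the band mask.
    T, D = q.shape
    scores = q @ k.T / jnp.sqrt(D)
    diff = jnp.arange(T)[:, None] - jnp.arange(T)[None, :]
    scores = jnp.where((diff >= 0) & (diff < window), scores, -1e9)
    return jax.nn.softmax(scores, axis=-1) @ v


def decode_with_rolling_cache(q, k, v, window):
    # Process positions one at a time; the cache never holds more than `window` entries.
    T, D = q.shape
    k_cache = jnp.zeros((window, D), dtype=q.dtype)
    v_cache = jnp.zeros((window, D), dtype=q.dtype)
    outputs = []
    for pos in range(T):
        slot = pos % window  # overwrites the entry written at pos - window
        k_cache = k_cache.at[slot].set(k[pos])
        v_cache = v_cache.at[slot].set(v[pos])
        filled = jnp.arange(window) < min(pos + 1, window)  # slots written so far
        scores = k_cache @ q[pos] / jnp.sqrt(D)              # (window,)
        scores = jnp.where(filled, scores, -1e9)
        weights = jax.nn.softmax(scores)
        outputs.append(weights @ v_cache)                    # (D,)
    return jnp.stack(outputs)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
T, D, W = 10, 4, 3
q = jax.random.normal(kq, (T, D))
k = jax.random.normal(kk, (T, D))
v = jax.random.normal(kv, (T, D))
cached = decode_with_rolling_cache(q, k, v, W)
reference = dense_sliding_window(q, k, v, W)
```

The cache holds W rows no matter how long decoding runs, yet the outputs should agree with the dense masked computation.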

Common pitfalls

  • Sign of i - j. If you write diff = j - i you get the transpose — looking forward instead of back. Check by printing in_band for small T.
  • Off-by-one. i - j < W (strict) gives a window of size W including the current token. i - j <= W gives W + 1.
  • Dropping the causal half. abs(diff) < W would build a symmetric band — letting tokens peek forward. Sliding-window attention is causal AND windowed.
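  • Size arguments arriving as floats. num_heads and d_model must be Python ints for reshape and nnx.Linear; cast window with int() as well so the mask compares integer offsets.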
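A cheap sanity check catches all of the pitfalls above: a correct mask has nothing above the diagonal, and row i keeps exactly min(i + 1, W) keys. looks_like_sliding_window is a hypothetical helper for illustration.

```python
import jax.numpy as jnp


def looks_like_sliding_window(mask, window):
    T = mask.shape[0]
    future = ~jnp.tri(T, k=0).astype(bool)  # True where j > i
    causal = not bool(jnp.any(mask & future))
    expected_width = jnp.minimum(jnp.arange(T) + 1, window)
    right_width = bool(jnp.array_equal(mask.sum(axis=1), expected_width))
    return causal and right_width


T, W = 6, 3
i = jnp.arange(T)[:, None]
j = jnp.arange(T)[None, :]

correct = ((i - j) >= 0) & ((i - j) < W)
flipped = ((j - i) >= 0) & ((j - i) < W)    # sign error: looks forward
too_wide = ((i - j) >= 0) & ((i - j) <= W)  # off-by-one: W + 1 keys
symmetric = jnp.abs(i - j) < W              # drops the causal half

for name, m in [("correct", correct), ("flipped", flipped),
                ("too_wide", too_wide), ("symmetric", symmetric)]:
    print(name, looks_like_sliding_window(m, W))
```

Only the correct variant passes; each broken one fails either the causality check or the row-width check.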

Problem

Write mha_sliding_window(seed, x, num_heads, d_model, window):

  1. Define SlidingWindowMHA(nnx.Module) with the standard four projections.
  2. __call__(x, window): compute scaled scores, build in_band from (diff >= 0) & (diff < window), fill the rest with -1e9, softmax, value matmul, out_proj.
  3. Cast every dimension arg to int. Return flattened.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model, window: ints (passed as floats).

Output: 1-D flattened.
