
Flash Attention Score Computation

Implement the online softmax trick used in Flash Attention from “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (Dao et al., 2022).

Instead of computing the full attention matrix, process blocks of keys and maintain running statistics. Given queries Q and keys K (single head), compute the numerically-stable attention output using the online softmax algorithm:

For each block of keys K_j:

  1. Compute scores: $S_j = Q \cdot K_j^T / \sqrt{d_k}$
  2. Track the running row-wise max: $m_{\text{new}} = \max(m_{\text{old}}, \operatorname{rowmax}(S_j))$
  3. Update running sum of exp: $\ell_{\text{new}} = \ell_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \sum e^{S_j - m_{\text{new}}}$
  4. Update the output accumulator with rescaling: $O_{\text{new}} = O_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + e^{S_j - m_{\text{new}}} \cdot V_j$ (after the last block, divide by $\ell$ to get the normalized output)

The final result must be numerically equivalent to standard scaled dot-product attention over Q, K, V; the online algorithm is simply a blockwise, numerically stable way to compute it without materializing the full attention matrix.

Input: Q, K, V, each of shape (seq_len, d_k).
Output: Attention output of shape (seq_len, d_k).
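The blockwise update steps above can be sketched in NumPy as follows. This is a minimal single-head sketch, not the paper's tiled CUDA kernel: the function name, `block_size` parameter, and variable names are illustrative choices, and only the key dimension is blocked (the full Q is kept in memory for clarity).

```python
import numpy as np

def flash_attention(Q, K, V, block_size=4):
    """Attention via online softmax over blocks of keys (single head).

    Illustrative sketch: maintains running row-wise max m, running
    normalizer l, and an unnormalized output accumulator O.
    """
    seq_len, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)

    m = np.full((seq_len, 1), -np.inf)  # running row-wise max
    l = np.zeros((seq_len, 1))          # running sum of exponentials
    O = np.zeros((seq_len, d_k))        # unnormalized output accumulator

    for start in range(0, seq_len, block_size):
        K_j = K[start:start + block_size]
        V_j = V[start:start + block_size]

        # 1. Scores for this block of keys: (seq_len, block).
        S_j = Q @ K_j.T * scale

        # 2. New running row-wise max.
        m_new = np.maximum(m, S_j.max(axis=1, keepdims=True))

        # 3. Rescale the old sum and add this block's exponentials.
        p = np.exp(S_j - m_new)
        correction = np.exp(m - m_new)  # 0 on the first block (m = -inf)
        l = l * correction + p.sum(axis=1, keepdims=True)

        # 4. Rescale the accumulator and add this block's contribution.
        O = O * correction + p @ V_j
        m = m_new

    # Normalize once at the end.
    return O / l
```

A quick way to validate the sketch is to compare it against standard attention computed with a full softmax; the two should agree to floating-point precision for any block size.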

Hints

flash-attention online-softmax dao-2022 attention efficiency