Implement the online softmax trick used in Flash Attention from “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (Dao et al., 2022).
Instead of materializing the full attention matrix, process the keys and values in blocks and maintain running statistics. Given queries Q, keys K, and values V (single head), compute the numerically stable attention output using the online softmax algorithm.
For each block of keys K_j (with the matching value block V_j):
- compute the block scores S_j = Q K_j^T / sqrt(d_k);
- update the running row-wise maximum m_new = max(m, rowmax(S_j));
- rescale the running denominator l and the running output accumulator O by exp(m - m_new), then add the row-wise sum of exp(S_j - m_new) to l and exp(S_j - m_new) V_j to O.
The final output O / l must be exactly equivalent to standard scaled dot-product attention over Q, K, V.
Input: Q, K, V each of shape (seq_len, d_k)
Output: Attention output of shape (seq_len, d_k).
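A minimal NumPy sketch of the steps above (the function name, block size, and helper variables are illustrative, not prescribed by the task):

```python
import numpy as np

def online_softmax_attention(Q, K, V, block_size=32):
    """Online-softmax attention: process K/V in blocks, keeping a running
    row-wise max (m), softmax denominator (l), and unnormalized output (O)."""
    seq_len, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)
    m = np.full((seq_len, 1), -np.inf)  # running row-wise max of scores
    l = np.zeros((seq_len, 1))          # running softmax denominator
    O = np.zeros((seq_len, d_k))        # running unnormalized output
    for j in range(0, seq_len, block_size):
        K_j = K[j:j + block_size]
        V_j = V[j:j + block_size]
        S_j = (Q @ K_j.T) * scale                 # block scores, (seq_len, block)
        m_new = np.maximum(m, S_j.max(axis=1, keepdims=True))
        P_j = np.exp(S_j - m_new)                 # numerically stable exponentials
        alpha = np.exp(m - m_new)                 # rescale factor for old statistics
        l = alpha * l + P_j.sum(axis=1, keepdims=True)
        O = alpha * O + P_j @ V_j
        m = m_new
    return O / l                                  # normalize at the end
```

On the first block, m is -inf, so alpha = exp(-inf - m_new) = 0 and the old (empty) statistics drop out cleanly; the result should match standard attention up to floating-point rounding for any block size.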