
KV Cache for Autoregressive Decoding

Implement KV cache management for autoregressive transformer decoding.

During autoregressive generation, each new token attends to all previously generated tokens and to itself. The KV cache stores the previously computed key and value projections so they are not recomputed at every step.

Given:

  • cached_K: shape (cache_len, d_k) — previously cached keys
  • cached_V: shape (cache_len, d_k) — previously cached values
  • new_q: shape (1, d_k) — query for the new token
  • new_k: shape (1, d_k) — key for the new token
  • new_v: shape (1, d_k) — value for the new token

Steps:

  1. Append new_k to cached_K, and new_v to cached_V
  2. Compute attention: output = softmax(new_q @ full_K^T / sqrt(d_k)) @ full_V

Output: A dict with:

  • "output": shape (1, d_k) — the attention output for the new token
  • "updated_K": shape (cache_len+1, d_k) — updated key cache
  • "updated_V": shape (cache_len+1, d_k) — updated value cache
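The two steps above can be sketched in NumPy as follows; the function name `kv_cache_step` is illustrative, not part of the task's API:

```python
import numpy as np

def kv_cache_step(cached_K, cached_V, new_q, new_k, new_v):
    """One autoregressive decoding step with a KV cache.

    Shapes: cached_K, cached_V are (cache_len, d_k);
            new_q, new_k, new_v are (1, d_k).
    """
    d_k = new_q.shape[-1]

    # Step 1: append the new key/value to the cache so the new token
    # can attend to itself as well as all cached positions.
    full_K = np.concatenate([cached_K, new_k], axis=0)  # (cache_len+1, d_k)
    full_V = np.concatenate([cached_V, new_v], axis=0)  # (cache_len+1, d_k)

    # Step 2: scaled dot-product attention for the single new query.
    scores = new_q @ full_K.T / np.sqrt(d_k)            # (1, cache_len+1)
    scores -= scores.max(axis=-1, keepdims=True)        # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)      # softmax over keys
    output = weights @ full_V                           # (1, d_k)

    return {"output": output, "updated_K": full_K, "updated_V": full_V}
```

Subtracting the row max before exponentiating leaves the softmax unchanged but avoids overflow for large score magnitudes.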
