Implement KV cache management for autoregressive transformer decoding.
During autoregressive generation, each new token attends to all previously generated tokens (and itself). The KV cache stores the key and value projections computed at earlier steps so they do not have to be recomputed on every decoding step.
Given:
- cached_K: shape (cache_len, d_k) — previously cached keys
- cached_V: shape (cache_len, d_k) — previously cached values
- new_q: shape (1, d_k) — query for the new token
- new_k: shape (1, d_k) — key for the new token
- new_v: shape (1, d_k) — value for the new token

Steps:
1. Append new_k and new_v to the caches, producing updated_K and updated_V of shape (cache_len+1, d_k).
2. Compute scaled dot-product scores between new_q and all keys: new_q @ updated_K.T / sqrt(d_k), shape (1, cache_len+1).
3. Apply softmax over the key axis to obtain attention weights.
4. Compute the attention output as weights @ updated_V, shape (1, d_k).
Output: A dict with:
"output": shape (1, d_k) — the attention output for the new token "updated_K": shape (cache_len+1, d_k) — updated key cache "updated_V": shape (cache_len+1, d_k) — updated value cache