Causal LM with KV Cache Generation
Implement autoregressive greedy generation with a KV cache — the technique that makes LLM inference fast enough to be practical.
Why KV cache?
In naive autoregressive generation, every new token requires a full forward pass over the entire sequence seen so far. Self-attention over T tokens costs O(T²), so each new token costs O(T²) and generating T tokens costs O(T³) in total, which becomes expensive as T grows.
The KV cache (key-value cache) eliminates the redundant work:
- During the initial prompt pass, compute K and V tensors for every prompt token and save them, per attention block.
- For each new token, only compute Q/K/V for that single token. Append the new K, V to the cached buffers, then run attention with new_Q against the full cached K and V.
This reduces the per-step attention work to O(T) (one query against a growing cache), bringing total generation cost to O(T²) instead of O(T³). The price is memory: the cache grows linearly with sequence length in every block, which is why production serving systems (vLLM, TGI, etc.) spend enormous engineering effort paging and managing it.
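To make the savings concrete, the sketch below counts query-key dot products per head per block, with and without a cache. The function name `attention_score_counts` is hypothetical, and the count deliberately ignores the causal mask (which would roughly halve the naive figure) and all non-attention work.

```python
def attention_score_counts(prompt_len, new_tokens):
    """Count query-key dot products for one head in one block.

    naive:  every prediction recomputes the full seq_len x seq_len score matrix.
    cached: the prompt pass computes the full matrix once; each later step
            computes a single row (one query against the cache).
    """
    naive = 0
    cached = 0
    for step in range(new_tokens):
        seq_len = prompt_len + step  # tokens visible when predicting the next one
        naive += seq_len * seq_len
        cached += seq_len * seq_len if step == 0 else seq_len
    return naive, cached


naive, cached = attention_score_counts(prompt_len=4, new_tokens=3)
print(naive, cached)  # 77 27
```

With a longer prompt or more generated tokens, the gap widens quickly because the naive count grows cubically and the cached count quadratically.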
Algorithm
Step 1 — Initial prompt pass:
x = w_emb[prompt] + pos_embed[:T_prompt]  # (1, T_prompt, d_model)
causal_mask = tril(ones(T_prompt, T_prompt))
for each block i:
    Q, K, V = x @ w_q, x @ w_k, x @ w_v
    cached_K[i] = K  # shape (1, num_heads, T_prompt, d_head)
    cached_V[i] = V
    x = post_ln_mha(x, Q, K, V, causal_mask)
    x = post_ln_ffn(x)
logits = x @ w_head
first_new_token = argmax(logits[0, -1])
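The cached shape `(1, num_heads, T_prompt, d_head)` implies that K and V are split into heads before being stored. A minimal NumPy sketch of that step for a single block follows. The `split_heads` helper and the reshape-then-transpose convention are assumptions; the problem statement does not specify how heads are laid out.

```python
import numpy as np


def split_heads(t: np.ndarray, num_heads: int) -> np.ndarray:
    """(batch, T, d_model) -> (batch, num_heads, T, d_head). Assumed layout."""
    batch, seq_len, d_model = t.shape
    d_head = d_model // num_heads
    return t.reshape(batch, seq_len, num_heads, d_head).transpose(0, 2, 1, 3)


rng = np.random.default_rng(0)
T_prompt, d_model, num_heads = 5, 8, 2
x = rng.standard_normal((1, T_prompt, d_model))   # stand-in for embedded prompt
w_k = rng.standard_normal((d_model, d_model))
w_v = rng.standard_normal((d_model, d_model))

# Project once, split into heads, and keep the result for later decode steps.
cached_K = split_heads(x @ w_k, num_heads)
cached_V = split_heads(x @ w_v, num_heads)

# Lower-triangular mask: position t may attend to positions 0..t only.
causal_mask = np.tril(np.ones((T_prompt, T_prompt)))

print(cached_K.shape)  # (1, 2, 5, 4)
```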
Step 2 — Per-token generation loop (for step in 0..max_new-2):
x_new = w_emb[token] + pos_embed[T_prompt + step]  # (1, 1, d_model)
for each block i:
    Q_new, K_new, V_new = x_new @ w_q, x_new @ w_k, x_new @ w_v
    cached_K[i] = cat([cached_K[i], K_new], dim=seq)
    cached_V[i] = cat([cached_V[i], V_new], dim=seq)
    # Attend new token to full cache; no causal mask needed here,
    # because the cache already holds only valid past tokens.
    scores = Q_new @ cached_K[i].T / sqrt(d_head)
    attn = softmax(scores)
    x_new = post_ln_mha(x_new, attn, cached_V[i])
    x_new = post_ln_ffn(x_new)
next_token = argmax(x_new @ w_head)[0, 0]
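The claim that no causal mask is needed during decoding can be checked numerically. The sketch below, a single-head toy in NumPy with random Q, K, V (all values are illustrative, not from the problem), compares one cached decode step against the last row of full masked attention.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
T, d_head = 6, 4
Q = rng.standard_normal((T, d_head))
K = rng.standard_normal((T, d_head))
V = rng.standard_normal((T, d_head))

# Reference: full causal attention over all T tokens.
scores = Q @ K.T / np.sqrt(d_head)
mask = np.tril(np.ones((T, T), dtype=bool))
full_out = softmax(np.where(mask, scores, -np.inf)) @ V   # (T, d_head)

# Cached decode step: the cache holds the first T-1 tokens, then we append
# the new token's K and V and attend with the new token's query only.
cached_K = np.concatenate([K[:-1], K[-1:]], axis=0)
cached_V = np.concatenate([V[:-1], V[-1:]], axis=0)
q_new = Q[-1:]                                            # (1, d_head)
step_out = softmax(q_new @ cached_K.T / np.sqrt(d_head)) @ cached_V

print(np.allclose(step_out[0], full_out[-1]))  # True
```

The two agree because the last row of the causal mask is all ones: the newest token is allowed to see every cached position, so masking would be a no-op.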
Step 3 — Return cat([prompt, generated]).
POST-LN Convention
Same as causal-lm-forward-pass:
attn_out = MHA(x)
x = LN(x + attn_out) # residual THEN layer norm
ffn_out = FFN(x)
x = LN(x + ffn_out)
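As a rough NumPy sketch of this ordering, the block below wires two sublayers through residual-then-normalize. The `mha` and `ffn` arguments are hypothetical callables standing in for the real sublayers, and the layer norm is assumed to use the biased variance over the last axis with no learned gain or bias.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Assumption: biased variance over the feature axis, no gain/bias terms.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def post_ln_block(x, mha, ffn):
    """Post-LN: add the residual first, then normalize, for each sublayer."""
    x = layer_norm(x + mha(x))
    x = layer_norm(x + ffn(x))
    return x


x = np.random.default_rng(0).standard_normal((1, 3, 8))
# Toy stand-ins for the attention and feed-forward sublayers.
y = post_ln_block(x, mha=lambda t: 0.5 * t, ffn=lambda t: np.tanh(t))
print(y.shape)  # (1, 3, 8)
```

Because the final operation is a layer norm, every output position has approximately zero mean and unit variance across features, which is a handy sanity check when debugging.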
Weight Packing
blocks_weights has shape (num_blocks, 6, d_model, d_model).
For block i, blocks_weights[i, 0..5] are, in order,
[w_q, w_k, w_v, w_o, w_mlp1, w_mlp2]. All six matrices are d_model × d_model, so d_ff = d_model.
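Unpacking one block's weights from this layout could look like the following NumPy sketch. The helper name `unpack_block` and the random weights are illustrative only.

```python
import numpy as np

num_blocks, d_model = 2, 8
rng = np.random.default_rng(0)
blocks_weights = rng.standard_normal((num_blocks, 6, d_model, d_model))


def unpack_block(blocks_weights, i):
    """Return the six (d_model, d_model) matrices of block i, in packed order."""
    w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = blocks_weights[i]
    return w_q, w_k, w_v, w_o, w_mlp1, w_mlp2


w_q, w_k, w_v, w_o, w_mlp1, w_mlp2 = unpack_block(blocks_weights, 0)
print(w_o.shape)  # (8, 8)
```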
Implementation Constraints
- No model.generate(...) and no nn.MultiheadAttention: implement the KV cache by hand.
- Manual layer norm: (x - mean) / sqrt(var + eps), eps=1e-5.
- Manual GELU: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
- Edge case: if max_new_tokens <= 0, return the prompt unchanged.
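The manual GELU formula translates directly into NumPy. The sketch below also shows one plausible feed-forward sublayer built from it; the `ffn` structure (two bias-free projections with GELU in between) is an assumption, since the problem statement only names the matrices w_mlp1 and w_mlp2.

```python
import numpy as np


def gelu(x):
    """Tanh approximation of GELU, written out by hand."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def ffn(x, w_mlp1, w_mlp2):
    # Assumption: no bias terms; GELU sits between the two projections.
    return gelu(x @ w_mlp1) @ w_mlp2


rng = np.random.default_rng(0)
d_model = 8
x = rng.standard_normal((1, 3, d_model))
w_mlp1 = rng.standard_normal((d_model, d_model))
w_mlp2 = rng.standard_normal((d_model, d_model))

print(ffn(x, w_mlp1, w_mlp2).shape)   # (1, 3, 8)
print(float(gelu(np.array(0.0))))     # 0.0
```

A quick self-check for this formula: because tanh is odd, gelu(x) - gelu(-x) equals x exactly, which catches most transcription mistakes in the constants or the cubic term.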
Output
Return a 1-D integer tensor of shape (T_prompt + max_new_tokens,) —
the prompt tokens followed by every generated token (greedy argmax,
no EOS halting in v1).
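The output contract and the edge case can be sketched independently of the model. In the NumPy example below, `next_token_fn` is a hypothetical stand-in for the cached forward pass plus argmax; it is not part of the problem's API and does not itself use a KV cache.

```python
import numpy as np


def greedy_generate(prompt, max_new_tokens, next_token_fn):
    """Return prompt followed by max_new_tokens greedily chosen tokens.

    next_token_fn(tokens) -> int is a placeholder for the real model step.
    """
    prompt = np.asarray(prompt, dtype=np.int64)
    if max_new_tokens <= 0:
        return prompt  # edge case: nothing to generate
    tokens = list(prompt)
    for _ in range(max_new_tokens):  # no EOS check: always run to the limit
        tokens.append(int(next_token_fn(tokens)))
    return np.asarray(tokens, dtype=np.int64)


# Toy "model": next token is the previous token plus one, modulo 10.
toy = lambda toks: (toks[-1] + 1) % 10
out = greedy_generate([3, 4], 3, toy)
print(out)  # [3 4 5 6 7]
```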
References
- Radford et al., “Language Models are Unsupervised Multitask Learners” (GPT-2), OpenAI 2019.
- Pope et al., “Efficiently Scaling Transformer Inference”, MLSys 2023.
- vLLM: https://vllm.ai — production KV cache management at scale.