hard end_to_end

Causal LM with KV Cache Generation

Implement autoregressive greedy generation with a KV cache — the technique that makes LLM inference fast enough to be practical.

Why KV cache?

In naive autoregressive generation, every new token requires a full forward pass over the entire sequence seen so far. Attention over a length-T sequence costs O(T²), so each new token costs O(T²) and generating T tokens costs O(T³) in total, which becomes expensive as T grows.

The KV cache (key-value cache) eliminates the redundant work:

  • During the initial prompt pass, compute K and V tensors for every prompt token and save them, per attention block.
  • For each new token, only compute Q/K/V for that single token. Append the new K, V to the cached buffers, then run attention with new_Q against the full cached K and V.

This reduces the per-step attention work to O(T) (one query against a growing cache), bringing total generation cost to O(T²) instead of O(T³). Production serving systems such as vLLM and TGI put substantial engineering effort into managing and paging this cache.
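To make the asymptotic claim concrete, the sketch below counts query-key dot products per attention head for both strategies. It is a rough illustration only: the function name `attention_score_counts` is hypothetical, the count ignores the causal mask's constant-factor saving, and it leaves out every cost other than attention scores.

```python
from typing import Tuple

def attention_score_counts(t_prompt: int, max_new: int) -> Tuple[int, int]:
    """Count query-key dot products (per head) for naive vs. cached decoding."""
    naive = 0
    cached = t_prompt * t_prompt  # one full prompt pass fills the cache
    for step in range(max_new):
        seq_len = t_prompt + step       # tokens visible when producing this token
        naive += seq_len * seq_len      # full pass: every query against every key
        if step > 0:
            cached += seq_len           # one new query against seq_len cached keys
    return naive, cached

print(attention_score_counts(t_prompt=4, max_new=3))  # (77, 27)
```

The gap widens quickly: the naive count grows roughly cubically with sequence length, the cached count roughly quadratically.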

Algorithm

Step 1 — Initial prompt pass:

x = w_emb[prompt] + pos_embed[:T_prompt]   # (1, T_prompt, d_model)
causal_mask = tril(ones(T_prompt, T_prompt))
for each block i:
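    Q, K, V = x @ w_q, x @ w_k, x @ w_v   # block i's weights; each (1, T_prompt, d_model)
    Q, K, V = split_heads(Q), split_heads(K), split_heads(V)   # each (1, num_heads, T_prompt, d_head)
    cached_K[i] = K
    cached_V[i] = V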
    x = post_ln_mha(x, Q, K, V, causal_mask)
    x = post_ln_ffn(x)
logits = x @ w_head
first_new_token = argmax(logits[0, -1])
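The cached shape (1, num_heads, T_prompt, d_head) implies a head split after the projections. Here is a minimal NumPy sketch of that reshaping for one block, assuming heads are contiguous slices of d_model (the statement does not specify the layout). `split_heads` and `prefill_block` are illustrative names, and NumPy stands in for whichever tensor library the solution uses.

```python
import numpy as np

def split_heads(t: np.ndarray, num_heads: int) -> np.ndarray:
    """(batch, seq, d_model) -> (batch, num_heads, seq, d_head)."""
    batch, seq, d_model = t.shape
    d_head = d_model // num_heads
    return t.reshape(batch, seq, num_heads, d_head).transpose(0, 2, 1, 3)

def prefill_block(x, w_q, w_k, w_v, num_heads):
    """Project the prompt once; return per-head Q plus the K and V to cache."""
    q = split_heads(x @ w_q, num_heads)
    k = split_heads(x @ w_k, num_heads)
    v = split_heads(x @ w_v, num_heads)
    return q, k, v

rng = np.random.default_rng(0)
t_prompt, d_model, num_heads = 5, 8, 2
x = rng.standard_normal((1, t_prompt, d_model))
w_q, w_k, w_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

q, cached_k, cached_v = prefill_block(x, w_q, w_k, w_v, num_heads)
print(cached_k.shape)  # (1, 2, 5, 4)
```

Keeping the sequence axis at position 2 means later steps can append one new key or value row per head with a single concatenation along that axis.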

Step 2 — Per-token generation loop (for step in 0..max_new-2):

x_new = w_emb[token] + pos_embed[T_prompt + step]   # token = latest generated token; (1, 1, d_model)
for each block i:
    Q_new, K_new, V_new = x_new @ w_q, x_new @ w_k, x_new @ w_v   # split into heads as in Step 1
    cached_K[i] = cat([cached_K[i], K_new], dim=seq)
    cached_V[i] = cat([cached_V[i], V_new], dim=seq)
    # Attend new token to full cache. No causal mask is needed here,
    # because the cache holds only the current and past tokens.
    scores = Q_new @ transpose(cached_K[i], -2, -1) / sqrt(d_head)   # (1, num_heads, 1, T_prompt + step + 1)
    attn = softmax(scores, dim=-1)
    x_new = post_ln_mha(x_new, attn, cached_V[i])
    x_new = post_ln_ffn(x_new)
next_token = argmax((x_new @ w_head)[0, 0])
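The claim that the cached step needs no mask can be checked numerically. The single-head NumPy sketch below compares the last row of full causal attention with one unmasked query against the cache. The function names are illustrative, and the random Q, K, V stand in for real projections.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def full_causal_attention(q, k, v):
    """q, k, v: (T, d_head). Masked attention over the whole sequence."""
    t, d_head = q.shape
    scores = q @ k.T / np.sqrt(d_head)
    mask = np.tril(np.ones((t, t), dtype=bool))
    scores = np.where(mask, scores, -np.inf)  # hide future positions
    return softmax(scores) @ v

def cached_step_attention(q_new, cached_k, cached_v):
    """q_new: (1, d_head); caches: (T, d_head). No mask: the cache has no future tokens."""
    scores = q_new @ cached_k.T / np.sqrt(q_new.shape[-1])
    return softmax(scores) @ cached_v

rng = np.random.default_rng(0)
T, d_head = 6, 4
q, k, v = (rng.standard_normal((T, d_head)) for _ in range(3))

# Build the cache as generation would: T-1 earlier rows, then append the newest.
cached_k = np.concatenate([k[:-1], k[-1:]], axis=0)
cached_v = np.concatenate([v[:-1], v[-1:]], axis=0)

full_last_row = full_causal_attention(q, k, v)[-1]
cached_out = cached_step_attention(q[-1:], cached_k, cached_v)[0]
print(np.allclose(full_last_row, cached_out))  # True
```

The two agree because the last row of the causal mask is all ones: the newest token may attend to every position already in the cache.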

Step 3 — Return cat([prompt, generated]).

POST-LN Convention

Same as causal-lm-forward-pass:

attn_out = MHA(x)
x = LN(x + attn_out)   # residual THEN layer norm
ffn_out = FFN(x)
x = LN(x + ffn_out)

Weight Packing

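blocks_weights has shape (num_blocks, 6, d_model, d_model). For block i, blocks_weights[i, 0] through blocks_weights[i, 5] are, in order: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2]. The feed-forward hidden width equals the model width (d_ff = d_model), so all six matrices are square.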
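One way to keep the index order straight is to unpack each block into named matrices once. The sketch below does this in NumPy; `WEIGHT_NAMES` and `unpack_block` are illustrative helpers, not part of the problem's interface.

```python
import numpy as np

# Order taken from the packing description above.
WEIGHT_NAMES = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")

def unpack_block(blocks_weights: np.ndarray, i: int) -> dict:
    """Return the six (d_model, d_model) matrices of block i, keyed by name."""
    assert blocks_weights.ndim == 4 and blocks_weights.shape[1] == len(WEIGHT_NAMES)
    return dict(zip(WEIGHT_NAMES, blocks_weights[i]))

num_blocks, d_model = 3, 8
blocks_weights = np.arange(num_blocks * 6 * d_model * d_model, dtype=np.float64)
blocks_weights = blocks_weights.reshape(num_blocks, 6, d_model, d_model)

block1 = unpack_block(blocks_weights, 1)
print(block1["w_o"].shape)  # (8, 8)
```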

Implementation Constraints

  • No model.generate(...), no nn.MultiheadAttention — implement KV cache by hand.
  • Manual layer norm: (x - mean) / sqrt(var + eps), eps=1e-5.
  • Manual GELU: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))).
  • Edge case: if max_new_tokens <= 0, return the prompt unchanged.
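The two manual formulas above can be sketched in NumPy as follows. Three details are assumptions rather than stated requirements: normalization runs over the last axis, the variance is the biased (divide-by-N) estimate, and there is no learned scale or shift (the weight packing lists no such parameters).

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize over the last axis; no learned scale or shift (assumed)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, divides by N (assumed)
    return (x - mean) / np.sqrt(var + eps)

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU, as given in the constraints."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

normed = layer_norm(np.array([[1.0, 2.0, 3.0, 4.0]]))
activated = gelu_tanh(np.array([-1.0, 0.0, 1.0]))
```

A quick sanity check for the activation: for this formula, gelu_tanh(x) - gelu_tanh(-x) equals x exactly, because tanh is odd.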

Output

Return a 1-D integer tensor of shape (T_prompt + max_new_tokens,) — the prompt tokens followed by every generated token (greedy argmax, no EOS halting in v1).
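The output contract and the edge case can be illustrated independently of the model. In the sketch below, `next_token_fn` is a hypothetical stand-in for the whole cached forward pass (in a real solution the KV cache would live behind it and only the newest token would be processed), and the toy rule is invented purely so the example runs.

```python
import numpy as np
from typing import Callable

def greedy_generate(prompt: np.ndarray, max_new_tokens: int,
                    next_token_fn: Callable[[np.ndarray], int]) -> np.ndarray:
    """Append max_new_tokens greedy picks to a 1-D integer prompt."""
    if max_new_tokens <= 0:
        return prompt  # edge case: prompt returned unchanged
    tokens = list(prompt)
    for _ in range(max_new_tokens):  # no EOS check: always emit max_new_tokens
        tokens.append(next_token_fn(np.asarray(tokens)))
    return np.asarray(tokens, dtype=prompt.dtype)

# Hypothetical toy "model": next token is (last token + 1) mod 10.
toy_next = lambda seq: int((seq[-1] + 1) % 10)

prompt = np.array([7, 8], dtype=np.int64)
out = greedy_generate(prompt, 4, toy_next)
print(out.tolist())  # [7, 8, 9, 0, 1, 2]
```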

References

  • Radford et al., “Language Models are Unsupervised Multitask Learners” (GPT-2), OpenAI 2019.
  • Pope et al., “Efficiently Scaling Transformer Inference”, MLSys 2023.
  • vLLM: https://vllm.ai — production KV cache management at scale.
