Implement the acceptance probability computation for Speculative Decoding from “Fast Inference from Transformers via Speculative Decoding” (Leviathan et al., 2023).
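Speculative decoding uses a small draft model to propose tokens cheaply, and a large target model then verifies them in parallel. Accepting each drafted token with the probability below, and resampling from a residual distribution when it is rejected, guarantees that the output tokens are distributed exactly as if they had been sampled from the target model.
For each drafted token x, let q(x) be its probability under the draft model and p(x) its probability under the target model.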
The acceptance probability is: $$\alpha(x) = \min\left(1, \frac{p(x)}{q(x)}\right)$$
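A quick numeric illustration with made-up values. A token that the draft model over-proposes is accepted only part of the time. A token that the target model assigns at least as much probability is always accepted.
$$q(x) = 0.5,\quad p(x) = 0.2 \;\Rightarrow\; \alpha(x) = \min\left(1, \frac{0.2}{0.5}\right) = 0.4$$
$$q(x) = 0.3,\quad p(x) = 0.6 \;\Rightarrow\; \alpha(x) = \min\left(1, \frac{0.6}{0.3}\right) = 1$$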
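If the drafted token is rejected, a replacement token is sampled from the residual distribution, where normalize rescales the vector so that it sums to 1 over the vocabulary: $$p'(x) = \text{normalize}(\max(0, p(x) - q(x)))$$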
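The following short check shows why the two rules together reproduce the target distribution. Let β be the probability that a drafted token is accepted. Because p and q each sum to 1, the probability that one step outputs token x works out to exactly p(x):
$$\beta = \sum_{x} q(x)\,\alpha(x) = \sum_{x} \min\big(p(x), q(x)\big), \qquad 1 - \beta = \sum_{x} \max\big(0,\, p(x) - q(x)\big)$$
$$P(\text{output} = x) = q(x)\,\alpha(x) + (1 - \beta)\,p'(x) = \min\big(p(x), q(x)\big) + \max\big(0,\, p(x) - q(x)\big) = p(x)$$
The second equality uses the fact that the normalizer of p' is exactly 1 − β.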
Given:
- draft_probs: shape (vocab_size,), the draft model distribution q
- target_probs: shape (vocab_size,), the target model distribution p
- drafted_token: integer, the token x proposed by the draft model

Output: a dict with:
- "accept_prob": float, the probability α(x) of accepting the drafted token
- "residual": shape (vocab_size,), the residual distribution p' to sample from if the drafted token is rejected
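Below is a minimal NumPy sketch of the computation. It assumes the inputs are 1-D arrays (or sequences) of probabilities that each sum to 1. The function names speculative_acceptance and speculative_sample are hypothetical and do not come from the paper. Two edge-case choices are also assumptions:
- The sketch raises an error when q(x) = 0, since a token sampled from q should never have zero draft probability.
- When p equals q everywhere, it returns p as the residual, because rejection then has probability 0 and the residual is never sampled.

```python
import numpy as np


def speculative_acceptance(draft_probs, target_probs, drafted_token):
    """Acceptance probability and residual distribution for one drafted token.

    Hypothetical helper name; q = draft_probs, p = target_probs.
    """
    q = np.asarray(draft_probs, dtype=np.float64)
    p = np.asarray(target_probs, dtype=np.float64)
    if q.ndim != 1 or q.shape != p.shape:
        raise ValueError("draft_probs and target_probs must both have shape (vocab_size,)")

    q_x = q[drafted_token]
    p_x = p[drafted_token]
    # Assumption: x was sampled from q, so q(x) > 0; otherwise alpha(x) is undefined.
    if q_x <= 0:
        raise ValueError("drafted_token has zero probability under the draft model")

    # alpha(x) = min(1, p(x) / q(x))
    accept_prob = float(min(1.0, p_x / q_x))

    # Residual: max(0, p - q), renormalized to sum to 1 over the vocabulary.
    residual = np.maximum(p - q, 0.0)
    total = residual.sum()
    if total > 0:
        residual = residual / total
    else:
        # p == q everywhere: rejection has probability 0, so the residual is never
        # sampled. Returning p here is an assumption that keeps the output a valid
        # distribution.
        residual = p.copy()

    return {"accept_prob": accept_prob, "residual": residual}


def speculative_sample(draft_probs, target_probs, rng):
    """One draft-then-verify step for a single position (illustrative sketch)."""
    q = np.asarray(draft_probs, dtype=np.float64)
    x = int(rng.choice(len(q), p=q))  # the draft model proposes x ~ q
    out = speculative_acceptance(draft_probs, target_probs, x)
    if rng.random() < out["accept_prob"]:  # accept with probability alpha(x)
        return x
    # Rejected: resample a replacement token from the residual p'.
    return int(rng.choice(len(q), p=out["residual"]))


if __name__ == "__main__":
    q = np.array([0.5, 0.3, 0.2])  # made-up draft distribution
    p = np.array([0.2, 0.6, 0.2])  # made-up target distribution
    out = speculative_acceptance(q, p, drafted_token=0)
    print(out["accept_prob"], out["residual"])
```

With these made-up distributions, token 0 is accepted with probability 0.4. The residual puts all of its mass on token 1, the only token the target model prefers more than the draft model does. Drawing many samples from speculative_sample and comparing the empirical frequencies with target_probs is a quick way to check that the accept-or-resample procedure reproduces the target distribution.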