
Top-k Sampling

Implement top-k sampling, a stochastic decoding strategy that avoids low-probability tail tokens while still allowing diversity.

What is top-k sampling?

At each decoding step, instead of always picking the highest-logit token (greedy) or sampling from the full vocabulary, top-k sampling:

  1. Filters the vocabulary to the k tokens with the highest logits. All other logits are set to −∞, giving them zero probability after softmax.
  2. Renormalizes by applying softmax over the filtered logits.
  3. Samples one token from the resulting distribution.
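As a concrete illustration of the three steps, take hypothetical logits [2.0, 1.0, 0.5, 3.0, -1.0] and k = 2. The framework-free sketch below (plain Python, not the reference solution) keeps indices 3 and 0, gives every other token zero probability, renormalizes the two survivors to roughly 0.731 and 0.269, and then draws one of them.

```python
import math
import random

# Hypothetical logits for a 5-token vocabulary, with k = 2.
logits = [2.0, 1.0, 0.5, 3.0, -1.0]
k = 2

# Step 1: find the k highest logits and set every other logit to -inf.
top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
filtered = [x if i in top else -math.inf for i, x in enumerate(logits)]

# Step 2: softmax over the filtered logits; exp(-inf) is 0.0.
m = max(filtered)  # subtract the max for numerical stability
exps = [math.exp(x - m) for x in filtered]
total = sum(exps)
probs = [e / total for e in exps]

# Step 3: draw one token index from the renormalized distribution.
token = random.Random(0).choices(range(len(probs)), weights=probs)[0]

print(top)                           # [3, 0]
print([round(p, 4) for p in probs])  # [0.2689, 0.0, 0.0, 0.7311, 0.0]
print(token in top)                  # True
```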

This prevents the sampler from emitting incoherent tokens from the long tail of low-probability vocabulary items, since every token outside the top k has exactly zero probability.

Algorithm

top_k_values, top_k_indices = topk(logits, k)
filtered_logits = fill(logits, -inf)
filtered_logits[top_k_indices] = top_k_values
probs = softmax(filtered_logits)
token = multinomial_sample(probs, seed=seed)
return token
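The pseudocode maps almost one-to-one onto PyTorch. Below is a minimal sketch, assuming a 1-D floating-point CPU tensor and 1 ≤ k ≤ vocab; the function name `top_k_sample` is hypothetical. It seeds a local `torch.Generator` rather than the global RNG; whether that reproduces the problem's expected samples exactly depends on how the reference solution seeds PyTorch, which the problem statement does not say.

```python
import torch

def top_k_sample(logits: torch.Tensor, k: int, seed: int) -> int:
    """Sample one token index from the k highest-logit entries of a 1-D tensor."""
    # Keep the k largest logits and remember their positions.
    top_k_values, top_k_indices = torch.topk(logits, k)
    # Every other position gets -inf, i.e. zero probability after softmax.
    filtered_logits = torch.full_like(logits, float("-inf"))
    filtered_logits[top_k_indices] = top_k_values
    # Renormalize over the k surviving tokens.
    probs = torch.softmax(filtered_logits, dim=-1)
    # A locally seeded generator makes the draw reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)
    token = torch.multinomial(probs, num_samples=1, generator=generator)
    return int(token.item())

logits = torch.tensor([2.0, 1.0, 0.5, 3.0, -1.0])
print(top_k_sample(logits, k=2, seed=0) in (0, 3))  # True
print(top_k_sample(logits, k=1, seed=0))            # 3 (k=1 is greedy)
```

Calling it twice with the same seed returns the same index, because each call builds a fresh generator from `seed`. Masking with −∞ (instead of sampling among the k survivors and mapping back) keeps the sampled index in vocabulary coordinates with no extra lookup.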

Relationship to other strategies

  • Greedy decoding is equivalent to top-k with k=1 (always takes the argmax).
  • Top-p (nucleus) sampling is the adaptive counterpart: instead of a fixed k, it keeps the smallest set of tokens whose cumulative probability reaches p, so the number of kept tokens varies from step to step.
  • Temperature scaling can be combined with top-k: apply temperature first, then top-k filter, then softmax + sample.
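The ordering in the temperature bullet can be sketched in plain Python. The helper `top_k_probs` below is hypothetical and returns only the renormalized distribution (sampling is omitted). Because dividing by a positive temperature does not change the ranking of the logits, the same k tokens survive at any temperature; the temperature only changes how peaked the distribution over them is. With k = 1 the result is one-hot on the argmax, matching greedy decoding.

```python
import math

def top_k_probs(logits, k, temperature=1.0):
    """Temperature-scale, keep the top k, then softmax (hypothetical helper)."""
    # 1. Temperature scaling (temperature > 0 preserves the ranking).
    scaled = [x / temperature for x in logits]
    # 2. Top-k filter: everything outside the k largest becomes -inf.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    keep = set(order[:k])
    filtered = [x if i in keep else -math.inf for i, x in enumerate(scaled)]
    # 3. Softmax over the filtered logits.
    m = max(filtered)
    exps = [math.exp(x - m) for x in filtered]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, 3.0, -1.0]
print([round(p, 3) for p in top_k_probs(logits, k=2, temperature=1.0)])  # [0.269, 0.0, 0.0, 0.731, 0.0]
print([round(p, 3) for p in top_k_probs(logits, k=2, temperature=0.5)])  # [0.119, 0.0, 0.0, 0.881, 0.0]
print(top_k_probs(logits, k=1))  # [0.0, 0.0, 0.0, 1.0, 0.0]
```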

PRNG note

PyTorch and JAX use different pseudo-random number generators, so given the same seed they will generally produce different samples. The expected outputs for this problem are generated using PyTorch only. Your JAX solution will be tested for correctness of the algorithm (right distribution, right shape), not for exact value matching against PyTorch’s samples.
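Since PyTorch and JAX draws cannot be compared value for value, one PRNG-independent way to check a sampler is statistical: draw once for each of many seeds and compare the observed frequencies with the softmax over the top-k logits. The sketch below uses Python's `random` module as a stand-in for either framework; `top_k_sample` is a hypothetical helper, and this is an illustration, not the problem's actual grading procedure. It samples directly among the k survivors, which is equivalent to masking the rest with −∞.

```python
import math
import random
from collections import Counter

def top_k_sample(logits, k, seed):
    """Hypothetical framework-free top-k sampler built on Python's random module."""
    # Indices of the k largest logits.
    keep = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Unnormalized softmax weights over the survivors; random.choices normalizes them.
    m = max(logits[i] for i in keep)
    weights = [math.exp(logits[i] - m) for i in keep]
    return random.Random(seed).choices(keep, weights=weights)[0]

logits = [2.0, 1.0, 0.5, 3.0, -1.0]
k, n = 2, 20_000

# Target distribution: softmax over the top-k logits (3.0 and 2.0) only.
z = math.exp(3.0) + math.exp(2.0)
expected = {3: math.exp(3.0) / z, 0: math.exp(2.0) / z}

# Empirical frequencies from one draw per seed.
counts = Counter(top_k_sample(logits, k, seed) for seed in range(n))
for token, p in expected.items():
    print(token, round(p, 3), round(counts[token] / n, 3))
```

With 20,000 draws the observed frequencies should typically land within about one percentage point of 0.731 and 0.269. A sampler that kept the wrong tokens, or sampled from the full vocabulary, would show up as indices outside {0, 3}.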

Inputs / Output

  • logits: 1-D tensor of shape (vocab,), the raw (unnormalized) scores.
  • k: int, the number of top tokens to keep.
  • seed: int, the random seed for reproducibility.

Output: a single integer (Python int), the index of the sampled token.

