Top-k Sampling
Implement top-k sampling, a stochastic decoding strategy that avoids low-probability tail tokens while still allowing diversity.
What is top-k sampling?
At each decoding step, instead of always picking the highest-logit token (greedy) or sampling from the full vocabulary, top-k sampling:
- Filters the vocabulary to the k tokens with the highest logits. All other logits are set to −∞, giving them zero probability after softmax.
- Renormalizes by applying softmax over the filtered logits.
- Samples one token from the resulting distribution.
This prevents the model from accidentally emitting incoherent tokens from the long tail of low-probability vocabulary items.
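To make the filter-and-renormalize steps concrete, here is a small worked example in plain Python. It is only a sketch: the helper name `top_k_probs` and the sample logits are made up for illustration, and it uses the standard library rather than PyTorch or JAX.

```python
import math

def top_k_probs(logits, k):
    """Return the distribution left after top-k filtering and softmax."""
    # Indices of the k largest logits (ties go to the lower index).
    keep = sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]
    keep_set = set(keep)
    # Subtract the largest kept logit for numerical stability.
    m = max(logits[i] for i in keep)
    # A masked logit of -inf contributes exp(-inf) = 0 to the softmax.
    weights = [math.exp(x - m) if i in keep_set else 0.0
               for i, x in enumerate(logits)]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical 4-token vocabulary, keeping the top 2.
probs = top_k_probs([2.0, 1.0, 0.5, -1.0], k=2)
print([round(p, 3) for p in probs])  # [0.731, 0.269, 0.0, 0.0]
```

The two discarded tokens end up with exactly zero probability, and the two kept tokens share the mass in proportion to the exponentials of their logits.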
Algorithm
top_k_values, top_k_indices = topk(logits, k)
filtered_logits = fill(logits, -inf)
filtered_logits[top_k_indices] = top_k_values
probs = softmax(filtered_logits)
token = multinomial_sample(probs, seed=seed)
return token
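The pseudocode above could be sketched without any framework as follows. This is an illustrative version under assumptions, not the reference solution: it uses Python's `random` module, so its samples will not match PyTorch's for the same seed, and the function name `top_k_sample` and the range check on `k` are additions made here.

```python
import math
import random

def top_k_sample(logits, k, seed):
    """Sample one token index from the k highest-logit tokens."""
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and the vocabulary size")
    # Step 1: indices of the k largest logits.
    top_idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Step 2: softmax weights over the kept logits only. Tokens outside
    # the top k would get weight exp(-inf) = 0, so they are simply omitted.
    m = max(logits[i] for i in top_idx)
    weights = [math.exp(logits[i] - m) for i in top_idx]
    # Step 3: draw one index in proportion to the weights
    # (random.choices normalizes the weights internally).
    rng = random.Random(seed)
    return rng.choices(top_idx, weights=weights, k=1)[0]

token = top_k_sample([2.0, 1.0, 0.5, -1.0], k=2, seed=0)
print(token)  # either 0 or 1, never 2 or 3
```

Seeding a local `random.Random` instance instead of the global generator keeps the call reproducible without affecting other code that draws random numbers.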
Relationship to other strategies
- Greedy decoding is equivalent to top-k with k=1 (always takes the argmax).
- Top-p (nucleus) sampling is the dynamic version: instead of a fixed k, it keeps the smallest set of tokens whose cumulative probability exceeds p.
- Temperature scaling can be combined with top-k: apply temperature first, then top-k filter, then softmax + sample.
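The relationships listed above can be illustrated with a short standard-library sketch. The helper names `sample_with_temperature` and `nucleus_indices`, and the example logits, are hypothetical and chosen only for this illustration.

```python
import math
import random

def sample_with_temperature(logits, k, temperature, seed):
    """Apply temperature, then the top-k filter, then softmax and sample."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    top_idx = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    m = max(scaled[i] for i in top_idx)
    weights = [math.exp(scaled[i] - m) for i in top_idx]
    return random.Random(seed).choices(top_idx, weights=weights, k=1)[0]

def nucleus_indices(logits, p):
    """Smallest set of top tokens whose cumulative probability exceeds p."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += exps[i] / total
        if cumulative > p:
            break
    return kept

logits = [0.1, 3.0, 2.5, -4.0]

# With k=1 only the argmax survives, so the result is greedy for any seed.
greedy = max(range(len(logits)), key=lambda i: logits[i])
print(sample_with_temperature(logits, k=1, temperature=0.7, seed=123) == greedy)  # True

# Top-p keeps a variable number of tokens depending on the distribution.
print(nucleus_indices(logits, 0.5))  # [1]
print(nucleus_indices(logits, 0.9))  # [1, 2]
```

Dividing by a positive temperature never changes which tokens rank in the top k; it only changes how the probability mass is shared among them.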
PRNG note
PyTorch and JAX use different pseudo-random number generators. Given the same
seed, they will produce different samples. The expected outputs for this
problem are generated using PyTorch only. Your JAX solution will be tested
for correctness of the algorithm (right distribution, right shape), not for
exact value matching against PyTorch's samples.
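One way to check for the right distribution rather than exact values is to draw many samples and compare the observed frequencies with the expected probabilities. The sketch below shows the idea with Python's `random` module. The tolerance, the number of draws, and the helper name are assumptions made here, not the grader's actual procedure.

```python
import math
import random
from collections import Counter

def top_k_sample(logits, k, rng):
    """Draw one top-k sample using a caller-supplied generator."""
    top_idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top_idx)
    weights = [math.exp(logits[i] - m) for i in top_idx]
    return rng.choices(top_idx, weights=weights, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]
k = 2
n_draws = 20_000
rng = random.Random(0)  # one generator reused across all draws
counts = Counter(top_k_sample(logits, k, rng) for _ in range(n_draws))

# With k=2 only tokens 0 and 1 survive; P(token 0) = 1 / (1 + e^-1).
expected_p0 = 1.0 / (1.0 + math.exp(-1.0))
observed_p0 = counts[0] / n_draws

only_top_k = set(counts) <= {0, 1}
close_enough = abs(observed_p0 - expected_p0) < 0.02
print(only_top_k, close_enough)
```

A check of this kind is independent of the underlying PRNG, which is why it can be applied to implementations whose individual samples differ for the same seed.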
Inputs / Output
- logits: 1-D tensor of shape (vocab,), holding raw (unnormalized) scores.
- k: int, the number of top tokens to keep.
- seed: int, the random seed for reproducibility.
Output: a single integer (Python int), the index of the sampled token.