Nucleus Sampling Generation Loop
Implement a full LLM-style generation loop combining temperature scaling and nucleus (top-p) sampling.
What is nucleus (top-p) sampling?
Introduced by Holtzman et al. 2019 (“The Curious Case of Neural Text Degeneration”), nucleus sampling dynamically selects the vocabulary to sample from at each step:
- Sort token probabilities in descending order.
- Compute the running (cumulative) sum of the sorted probabilities, from highest to lowest.
- Keep the smallest prefix whose cumulative probability is ≥ top_p. Always keep at least the top-1 token.
- Zero out the rest, renormalize, then sample.
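For reference, the same rule can be written in the notation of Holtzman et al. Here $P(x \mid x_{1:i-1})$ is the model's next-token distribution and $p$ plays the role of top_p. Sorting in descending order and taking the shortest prefix whose mass reaches $p$ is one way to find such a smallest set:

```latex
V^{(p)} = \text{smallest } V' \subseteq V \text{ such that } \sum_{x \in V'} P(x \mid x_{1:i-1}) \ge p,
\qquad p' = \sum_{x \in V^{(p)}} P(x \mid x_{1:i-1})

P'(x \mid x_{1:i-1}) =
\begin{cases}
  P(x \mid x_{1:i-1}) / p' & \text{if } x \in V^{(p)} \\
  0 & \text{otherwise}
\end{cases}
```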
The key insight: instead of a fixed k (top-k), the nucleus adapts to
the distribution shape. When the model is confident (one token dominates),
the nucleus is tiny. When the model is uncertain (flat distribution), the
nucleus is large. This avoids the pitfalls of both “too few tokens” and
“sampling from the long tail”.
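The following is a minimal pure-Python sketch of the filter. It assumes the probabilities arrive as a plain list rather than a tensor. The names top_p_filter and nucleus_size are hypothetical helpers, not part of the problem's API. The two toy distributions illustrate the adaptive behavior: with top_p = 0.9, the peaked one keeps a single token and the flat one keeps all four.

```python
def top_p_filter(probs, top_p):
    """Zero out everything outside the nucleus and renormalize (hypothetical helper)."""
    # Indices sorted by probability, highest first
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumsum = [], 0.0
    for i in order:
        kept.append(i)          # appended before the check, so the top-1 token is always kept
        cumsum += probs[i]
        if cumsum >= top_p:     # smallest prefix whose mass reaches top_p
            break
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass   # renormalize over the nucleus
    return filtered


def nucleus_size(probs, top_p):
    """Number of tokens that survive the top-p filter."""
    return sum(1 for p in top_p_filter(probs, top_p) if p > 0)


peaked = [0.92, 0.04, 0.03, 0.01]   # model is confident
flat = [0.30, 0.25, 0.25, 0.20]     # model is uncertain
print(nucleus_size(peaked, 0.9), nucleus_size(flat, 0.9))  # prints 1 4
```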
Temperature scaling
Before softmax, logits are divided by temperature:
- temperature < 1.0 → the distribution is sharper (more deterministic).
- temperature > 1.0 → the distribution is flatter (more diverse).
- temperature = 1.0 → no change (standard softmax).
- top_p = 1.0 → no nucleus filtering; sample from the full distribution.
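The sketch below applies a temperature-scaled softmax to a plain list of logits. Subtracting the maximum is a standard numerical-stability trick and does not change the result. The guard against a non-positive temperature is our own assumption, since the text above does not say how temperature = 0 should be handled.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide logits by temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        # Assumption: non-positive temperatures are rejected rather than treated as greedy
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # shift by the max; softmax is invariant to this shift
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    # Lower t puts more mass on the top token; higher t spreads it out
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```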
Per-step pipeline
seq = [int(t) for t in prompt]                  # start from the prompt ids, cast to int
for step in range(max_tokens):
    logits = logits_fn(seq)                     # (vocab,)
    logits = logits / temperature               # temperature scaling
    probs = softmax(logits)                     # normalize
    # --- top-p filter ---
    sort probs descending → sorted_probs, sorted_idx
    cumsum = cumulative_sum(sorted_probs)
    keep the smallest prefix where cumsum >= top_p
        (always keep at least the first token)
    set non-kept probs to 0; renormalize → filtered
    # --- sample ---
    token = multinomial(filtered, seed=seed + step)
    seq.append(token)
    if token == eos_id: break
return tensor(seq)
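The pipeline can be turned into runnable code as a pure-Python sketch with no tensor library. This version makes several assumptions: logits_fn returns a plain list of floats, the result is a list rather than a tensor, the argument order follows the Inputs list, and sampling uses Python's random.Random(seed + step) in place of a PyTorch generator. Because the PRNG differs, this sketch will not reproduce the expected PyTorch outputs value for value. It only illustrates the algorithm.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # stability shift
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def top_p_filter(probs, top_p):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumsum = [], 0.0
    for i in order:             # always keeps at least the top-1 token
        kept.append(i)
        cumsum += probs[i]
        if cumsum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered


def nucleus_generate(logits_fn, prompt, max_tokens, eos_id, temperature, top_p, seed):
    seq = [int(t) for t in prompt]                   # prompt ids cast to int
    for step in range(max_tokens):
        logits = logits_fn(seq)                      # one score per vocab entry
        probs = softmax([z / temperature for z in logits])
        filtered = top_p_filter(probs, top_p)
        rng = random.Random(seed + step)             # per-step seed (Python PRNG, not PyTorch)
        token = rng.choices(range(len(filtered)), weights=filtered, k=1)[0]
        seq.append(token)
        if token == eos_id:                          # EOS is kept in the output
            break
    return seq


toy = lambda seq: [0.0, 1.0, 2.0, 0.5]   # hypothetical logits_fn that ignores seq
print(nucleus_generate(toy, [5, 6], max_tokens=4, eos_id=3,
                       temperature=1.0, top_p=0.01, seed=0))  # prints [5, 6, 2, 2, 2, 2]
```

A very small top_p keeps only the top-1 token, so generation becomes greedy. This makes a convenient sanity check for any implementation.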
PRNG note
PyTorch and JAX use different pseudo-random number generators. Given
the same seed, they produce different samples. The expected outputs for
this problem are generated with PyTorch only; JAX solutions are tested
for algorithmic correctness but not for exact value matching.
Inputs / Output
- logits_fn: callable(seq: list[int]) -> tensor of shape (vocab,). In tests, it is passed as a string key that the test harness replaces with the actual function.
- prompt: 1-D tensor of starting token ids, shape (T_prompt,). Cast to int.
- max_tokens: int, the maximum number of tokens to generate (the prompt is not counted).
- eos_id: int. When this token is sampled, generation halts and the EOS token is included in the output.
- temperature: float; logits are divided by it before softmax.
- top_p: float in (0, 1], the nucleus threshold.
- seed: int. The seed for step i is seed + i, with the step index starting at 0.
Output: 1-D tensor of token ids, shape (T_prompt + n_generated,).
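The next block makes the contract concrete. It contains a hypothetical checker for the Output line (check_output is our name, not part of the test harness) and a toy logits_fn that shows the expected callable shape: it takes the current sequence and returns one score per vocabulary entry. Lists stand in for tensors.

```python
VOCAB_SIZE = 5  # hypothetical toy vocabulary size


def toy_logits_fn(seq):
    """Hypothetical logits_fn: strongly prefers (last token + 1) mod VOCAB_SIZE."""
    favored = (seq[-1] + 1) % VOCAB_SIZE
    return [4.0 if tok == favored else 0.0 for tok in range(VOCAB_SIZE)]


def check_output(prompt, out, max_tokens, eos_id):
    """Check an output sequence against the I/O contract; return n_generated."""
    prompt = [int(t) for t in prompt]
    generated = out[len(prompt):]
    assert out[:len(prompt)] == prompt, "output must start with the prompt"
    assert len(generated) <= max_tokens, "at most max_tokens new tokens"
    # Generation halts on EOS, so EOS may only be the last generated token
    assert eos_id not in generated[:-1], "nothing may follow EOS"
    if len(generated) < max_tokens:
        # Stopping early is only possible right after sampling EOS
        assert generated and generated[-1] == eos_id, "early stop without EOS"
    return len(generated)


print(len(toy_logits_fn([0, 1])))                                    # prints 5
print(check_output([0, 1], [0, 1, 2, 3, 4], max_tokens=3, eos_id=4))  # prints 3
```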