Difficulty: hard | Category: end_to_end

Nucleus Sampling Generation Loop

Implement a full LLM-style generation loop combining temperature scaling and nucleus (top-p) sampling.

What is nucleus (top-p) sampling?

Introduced by Holtzman et al. (2019) in “The Curious Case of Neural Text Degeneration”, nucleus sampling re-selects, at every step, the subset of the vocabulary to sample from:

  1. Sort token probabilities in descending order.
  2. Compute the cumulative sum of the sorted probabilities, from highest to lowest.
  3. Keep the smallest prefix whose cumulative probability ≥ top_p. Always keep at least the top-1 token.
  4. Zero out the rest, renormalize, then sample.
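
The four steps can be sketched in plain Python. This is a minimal sketch, not the reference solution: it works on a list of probabilities rather than a tensor, and top_p_filter is a hypothetical helper name. Ties between equal probabilities are broken by lower token id here (Python's sort is stable), which is an assumption; torch.sort only guarantees a stable order when called with stable=True.

```python
def top_p_filter(probs: list[float], top_p: float) -> list[float]:
    """Zero out tokens outside the nucleus and renormalize (vocab order kept)."""
    # Step 1: token ids sorted by probability, highest first (ties: lower id first).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Steps 2-3: grow the prefix until its cumulative mass reaches top_p.
    kept, cumsum = [], 0.0
    for i in order:
        kept.append(i)              # the top-1 token is always kept
        cumsum += probs[i]
        if cumsum >= top_p:
            break                   # if rounding never reaches top_p, all tokens stay
    # Step 4: zero out everything else and renormalize the survivors.
    total = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / total
    return filtered


print(top_p_filter([0.1, 0.5, 0.3, 0.1], top_p=0.7))
```

In this example tokens 1 and 2 are kept (0.5 + 0.3 = 0.8 ≥ 0.7) and renormalized to roughly 0.625 and 0.375; tokens 0 and 3 become 0.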

The key insight: unlike top-k sampling, which always keeps a fixed number k of tokens, the nucleus adapts to the shape of the distribution. When the model is confident (one token dominates), the nucleus is small and can shrink to a single token. When the model is uncertain (a flat distribution), the nucleus is large. A fixed k cannot do both: it either cuts off plausible tokens when the distribution is flat or admits low-probability tail tokens when it is peaked.
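
To make the adaptivity concrete, the sketch below counts how many tokens the nucleus contains for a peaked and a flat distribution at the same top_p. The two distributions and the helper name nucleus_size are made up for illustration.

```python
def nucleus_size(probs: list[float], top_p: float) -> int:
    """Size of the smallest descending-probability prefix whose mass reaches top_p."""
    cumsum = 0.0
    for n, p in enumerate(sorted(probs, reverse=True), start=1):
        cumsum += p
        if cumsum >= top_p:
            return n
    return len(probs)  # rounding left cumsum just below top_p: keep everything


confident = [0.95, 0.02, 0.01, 0.01, 0.01]  # one token dominates
uncertain = [0.22, 0.2, 0.2, 0.19, 0.19]    # nearly flat
print(nucleus_size(confident, 0.9))  # 1
print(nucleus_size(uncertain, 0.9))  # 5
```

For comparison, a fixed top-2 filter would keep the 0.02 tail token in the confident case and discard 58% of the probability mass in the uncertain case.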

Temperature scaling

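Before softmax, logits are divided by temperature (temperature > 0):

  • temperature < 1.0 → distribution is sharper (more deterministic).
  • temperature > 1.0 → distribution is flatter (more diverse).
  • temperature = 1.0 → no change (standard softmax).

The two knobs are independent: temperature reshapes the distribution, and the top-p filter then truncates it. Setting top_p = 1.0 disables the nucleus filter, so tokens are sampled from the full temperature-scaled distribution.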
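
The effect of temperature can be checked on a three-token example. This is a minimal sketch; the helper name softmax_with_temperature and the guard that rejects a non-positive temperature are choices made here, not part of the problem's API.

```python
import math


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Divide logits by temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)                            # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

The top token's probability falls from about 0.87 at temperature 0.5 to about 0.67 at 1.0 and about 0.51 at 2.0, while the ranking of the tokens never changes.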

Per-step pipeline

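seq = [int(t) for t in prompt]            # start from the prompt ids
for step in range(max_tokens):
    logits = logits_fn(seq)               # (vocab,)
    logits = logits / temperature         # temperature scaling
    probs  = softmax(logits)              # normalize
    # --- top-p filter ---
    sorted_probs, sorted_idx = sort(probs, descending=True)
    cumsum = cumulative_sum(sorted_probs)
    # keep the smallest prefix with cumsum >= top_p (always at least 1 token;
    # all tokens if rounding keeps cumsum just below top_p)
    n_keep = (first index where cumsum >= top_p) + 1
    filtered = zeros_like(probs)
    filtered[sorted_idx[:n_keep]] = probs[sorted_idx[:n_keep]]
    filtered = filtered / sum(filtered)   # renormalize
    # --- sample ---
    token = multinomial(filtered, seed=seed + step)
    seq.append(token)
    if token == eos_id:                   # EOS is kept in the output
        break
return tensor(seq)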
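
The pipeline can be turned into a runnable, dependency-free sketch. It is hedged in three ways: it uses Python lists instead of tensors (a PyTorch solution would return torch.tensor(seq)); it draws from Python's random.Random(seed + step), so it follows the same per-step seeding scheme but will not reproduce the PyTorch expected outputs; and generate and toy_logits_fn are hypothetical names. random.choices treats its weights as relative, so passing the kept probabilities directly is equivalent to renormalizing them. Sampling over the kept ids in sorted order gives the same distribution as sampling from the zero-padded vector, though not the same token for a given random draw.

```python
import math
import random
from typing import Callable


def generate(
    logits_fn: Callable[[list[int]], list[float]],
    prompt: list[int],
    max_tokens: int,
    eos_id: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
    seed: int = 0,
) -> list[int]:
    seq = [int(t) for t in prompt]                # prompt ids, cast to int
    for step in range(max_tokens):
        # Temperature scaling followed by a numerically stable softmax.
        scaled = [x / temperature for x in logits_fn(seq)]
        m = max(scaled)
        exps = [math.exp(x - m) for x in scaled]
        z = sum(exps)
        probs = [e / z for e in exps]
        # Top-p filter: smallest descending-probability prefix with mass >= top_p.
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        kept, cumsum = [], 0.0
        for i in order:
            kept.append(i)                        # the top-1 token is always kept
            cumsum += probs[i]
            if cumsum >= top_p:
                break
        # Sample from the nucleus with a generator seeded for this step only;
        # choices() normalizes the weights, which acts as the renormalization.
        rng = random.Random(seed + step)
        token = rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
        seq.append(token)
        if token == eos_id:                       # EOS is included, then stop
            break
    return seq


def toy_logits_fn(seq: list[int]) -> list[float]:
    """Hypothetical 5-token model that strongly prefers (last token + 1) mod 5."""
    nxt = (seq[-1] + 1) % 5
    return [5.0 if i == nxt else 0.0 for i in range(5)]


print(generate(toy_logits_fn, [0], max_tokens=10, eos_id=3,
               temperature=0.5, top_p=0.9, seed=0))  # → [0, 1, 2, 3]
```

With these settings the preferred token carries more than 0.9 of the mass at every step, so the nucleus is a single token and generation walks 0 → 1 → 2 → 3, stopping at EOS (3) with the EOS token included.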

PRNG note

PyTorch and JAX use different pseudo-random number generators. Given the same seed, they produce different samples. The expected outputs for this problem are generated with PyTorch only; JAX solutions are tested for algorithmic correctness but not for exact value matching.
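
Re-seeding a fresh generator at every step with seed + step makes each step's draw depend only on the seed, the step index, and that step's probabilities. The sketch below shows this with Python's random module and inverse-CDF sampling (sample_token is a hypothetical helper). It is reproducible run to run, but, as the note says, its samples will differ from torch.multinomial's for the same seed.

```python
import random


def sample_token(probs: list[float], seed: int) -> int:
    """Inverse-CDF draw from `probs` using a generator seeded just for this step."""
    u = random.Random(seed).random()    # uniform in [0, 1)
    cumsum = 0.0
    for token, p in enumerate(probs):
        cumsum += p
        if u < cumsum:
            return token
    # Rounding left cumsum below u: fall back to the last token with nonzero mass.
    return max(i for i, p in enumerate(probs) if p > 0)


probs = [0.1, 0.2, 0.3, 0.4]
seed = 1234
run_a = [sample_token(probs, seed + step) for step in range(5)]
run_b = [sample_token(probs, seed + step) for step in range(5)]
print(run_a == run_b)  # True: same seed and step index give the same tokens
```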

Inputs / Output

  • logits_fn: callable (seq: list[int]) -> tensor shape (vocab,). In tests, passed as a string key substituted by the test harness.
  • prompt: 1-D tensor of starting token ids (T_prompt,). Cast to int.
  • max_tokens: int — max tokens to generate (prompt not counted).
  • eos_id: int — end-of-sequence token id. When it is sampled, it is appended to the output and generation stops.
  • temperature: float — divides logits before softmax.
  • top_p: float in (0, 1] — nucleus threshold.
  • seed: int — per-step seed is seed + step_index.

Output: 1-D tensor of token ids, shape (T_prompt + n_generated,).

Tags: llm, decoding, generation
