Nucleus (Top-p) Masking
Why this matters
Nucleus sampling (Holtzman et al., 2020, “The Curious Case of Neural Text Degeneration”) improves on top-k by adapting to the shape of the distribution. Instead of always keeping a fixed k tokens, it keeps the smallest set of highest-probability tokens whose combined probability mass is at least p. When the model is confident, the nucleus shrinks to a few tokens. When the distribution is flat, it grows to include many plausible candidates. Either way, the low-probability tail is cut off without hand-tuning k. Top-p with p around 0.9 is a common default in LLM inference settings, and it is often used instead of, or together with, top-k.
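To make the "adapts to the distribution" point concrete, here is a minimal sketch that only counts how many tokens land in the nucleus. The helper `nucleus_size` and the two example logit vectors are hypothetical, chosen for illustration. The peaked distribution keeps a single token at p = 0.9, while the flat one keeps all five. A fixed top-k with k = 2 would keep two tokens in both cases.

```python
import jax
import jax.numpy as jnp

def nucleus_size(logits, p):
    """Hypothetical helper: number of tokens in the top-p nucleus, i.e. the
    smallest prefix of the descending-sorted distribution with mass >= p."""
    probs = jnp.sort(jax.nn.softmax(logits))[::-1]  # descending probabilities
    mass_before = jnp.cumsum(probs) - probs          # mass ranked strictly above each token
    return int(jnp.sum(mass_before < p))             # mass_before is non-decreasing

peaked = jnp.array([8.0, 2.0, 1.0, 0.5, 0.0])  # one dominant token (~0.996 mass)
flat = jnp.array([1.0, 0.9, 0.8, 0.7, 0.6])    # mass spread almost evenly

print(nucleus_size(peaked, 0.9))  # → 1
print(nucleus_size(flat, 0.9))    # → 5
```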
Worked mini-example
import jax, jax.numpy as jnp
logits = jnp.array([1.0, 2.0, 3.0, 4.0])
p = 0.9
MASK_VALUE = -1e9
# Sort descending
sorted_logits = jnp.sort(logits)[::-1] # [4.0, 3.0, 2.0, 1.0]
sorted_probs = jax.nn.softmax(sorted_logits)
# ≈ [0.644, 0.237, 0.087, 0.032]
cumulative_probs = jnp.cumsum(sorted_probs)
# ≈ [0.644, 0.881, 0.968, 1.0]
# keep[i] = True if the mass of the tokens ranked strictly above i is still below p
keep = cumulative_probs - sorted_probs < p
# exclusive cumsum ≈ [0.0, 0.644, 0.881, 0.968]
# → [True, True, True, False] (0.0, 0.644 and 0.881 are < 0.9; 0.968 is not)
# Always keep top-1
keep = keep | (jnp.arange(4) == 0)
# Smallest kept logit = threshold
threshold = jnp.where(keep, sorted_logits, jnp.inf).min() # 2.0 (smallest kept logit)
masked = jnp.where(logits >= threshold, logits, MASK_VALUE)
# → [-1e9, 2.0, 3.0, 4.0]
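After masking, the logits are usually handed to a categorical sampler. Below is a minimal sketch, assuming the masked array from the worked example, that uses `jax.random.categorical` (which samples indices from softmax(logits) along the last axis). It shows that the masked token gets zero probability and that the remaining mass is renormalized over the three nucleus tokens.

```python
import jax
import jax.numpy as jnp

# Output of the worked example (p = 0.9 on logits [1, 2, 3, 4]).
masked = jnp.array([-1e9, 2.0, 3.0, 4.0])

# Probabilities the sampler effectively uses: the masked entry underflows to 0
# and the surviving mass is renormalized over the nucleus.
probs = jax.nn.softmax(masked)  # ≈ [0.0, 0.090, 0.245, 0.665]

# Draw 1000 token indices from softmax(masked).
key = jax.random.PRNGKey(0)
samples = jax.random.categorical(key, masked, shape=(1000,))
counts = jnp.bincount(samples, length=masked.shape[0])
print(counts)  # counts[0] is 0: the masked token is never drawn
```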
Common pitfalls
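- Sort DESCENDING: use jnp.sort(logits)[::-1], not the ascending jnp.sort(logits).
- Edge case (top-1 exceeds p): always keep at least the top-1 token, even if its probability alone exceeds p. Guard with | (arange == 0). With the exclusive test below, index 0 has zero preceding mass and is kept for any p > 0, but the explicit guard makes the guarantee independent of how the cumulative test is written.
- Ties: the threshold comparison (>=) is applied to the original (unsorted) array, so every logit tied with the threshold value is kept. This can make the nucleus slightly larger than the minimal set.
- Cumsum logic: keep token i if the cumulative mass BEFORE including i is < p, i.e. cumsum[i] - prob[i] < p, not cumsum[i] < p. The inclusive test drops the token whose probability pushes the total past p, so the kept mass falls short of p.
- Real LLM code often uses -inf; we use -1e9 here so the output can be serialized to JSON for the test contract (standard JSON has no representation for infinity). In float32 the two behave the same for masking, because exp(-1e9 - max_logit) underflows to exactly 0 inside the softmax.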
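The cumsum and mask-value pitfalls can be checked directly on the worked example. This sketch compares the inclusive test (wrong) with the exclusive test (right) and confirms that -1e9 and -inf give the same float32 softmax. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0, 4.0])
p = 0.9
sorted_probs = jax.nn.softmax(jnp.sort(logits)[::-1])  # ≈ [0.644, 0.237, 0.087, 0.032]
cumsum = jnp.cumsum(sorted_probs)                       # ≈ [0.644, 0.881, 0.968, 1.0]

# Inclusive test (wrong): drops the token that pushes the mass past p.
wrong_keep = cumsum < p                   # True, True, False, False
# Exclusive test (right): keeps a token while the mass ranked above it is < p.
right_keep = (cumsum - sorted_probs) < p  # True, True, True, False

wrong_mass = jnp.sum(jnp.where(wrong_keep, sorted_probs, 0.0))  # ≈ 0.881, below p
right_mass = jnp.sum(jnp.where(right_keep, sorted_probs, 0.0))  # ≈ 0.968, covers p

# Masking with -1e9 or -inf yields the same float32 probabilities.
probs_big = jax.nn.softmax(jnp.array([-1e9, 2.0, 3.0, 4.0]))
probs_inf = jax.nn.softmax(jnp.array([-jnp.inf, 2.0, 3.0, 4.0]))
print(wrong_mass, right_mass, jnp.allclose(probs_big, probs_inf))
```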
Problem
Implement nucleus_top_p_mask(logits, p) that returns the logits with all
tokens outside the nucleus replaced by -1e9.
logits is a 1-D float array of shape (K,). p is a float in (0, 1].
Return a 1-D float32 array of the same shape (K,).
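One way to sanity-check an implementation against this contract is a small property test. The sketch below uses a hypothetical helper, `check_top_p_mask`; it is not the site's grader. It checks shape and dtype, checks that kept entries are unchanged, checks that the top-1 token survives, and checks that the kept tokens cover at least p of the original mass. It does not test minimality, since ties can legitimately enlarge the nucleus.

```python
import jax
import jax.numpy as jnp

MASK_VALUE = -1e9

def check_top_p_mask(masked, logits, p):
    """Hypothetical property checks for any nucleus_top_p_mask output."""
    masked = jnp.asarray(masked)
    logits = jnp.asarray(logits)
    assert masked.shape == logits.shape, "result must have shape (K,)"
    assert masked.dtype == jnp.float32, "result must be float32"
    kept = masked != MASK_VALUE
    # Kept entries must be unchanged copies of the input logits.
    assert bool(jnp.all(jnp.where(kept, masked == logits, True)))
    # The highest-logit token is always kept.
    assert bool(kept[jnp.argmax(logits)])
    # Kept tokens must cover at least p of the original probability mass
    # (small tolerance for float32 rounding when p = 1).
    probs = jax.nn.softmax(logits)
    assert float(jnp.sum(jnp.where(kept, probs, 0.0))) >= p - 1e-6
    return True

logits = jnp.array([1.0, 2.0, 3.0, 4.0])
worked = jnp.array([-1e9, 2.0, 3.0, 4.0], dtype=jnp.float32)  # worked-example output
print(check_top_p_mask(worked, logits, 0.9))  # → True
```

Running it on the output of the inclusive-cumsum mistake ([-1e9, -1e9, 3.0, 4.0]) raises an AssertionError, because the kept mass is only about 0.881.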