
Nucleus (Top-p) Masking

Why this matters

Nucleus sampling (Holtzman et al., 2020, “The Curious Case of Neural Text Degeneration”) adapts to the shape of the distribution where top-k does not: instead of always keeping a fixed k tokens, it keeps the smallest set of tokens whose combined probability mass is at least p. When the distribution is peaked the nucleus shrinks to a few tokens, and when it is flat the nucleus grows, so the low-probability tail is cut off in both cases. Values around p=0.9 are a common setting in LLM deployments, and top-p is often used in place of, or together with, top-k.
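
To make the adaptive behaviour concrete, the sketch below counts how many tokens fall inside the nucleus for a peaked and a flat distribution. The helper name nucleus_size and both example arrays are illustrative assumptions, not part of the exercise.

```python
import jax
import jax.numpy as jnp


def nucleus_size(logits, p):
    """Count the tokens in the smallest set whose probability mass reaches p.

    Hypothetical helper, introduced only for this illustration.
    """
    probs = jnp.sort(jax.nn.softmax(logits))[::-1]  # probabilities, descending
    mass_before = jnp.cumsum(probs) - probs          # mass of strictly higher-ranked tokens
    return int(jnp.sum(mass_before < p))


peaked = jnp.array([10.0, 1.0, 0.5, 0.0, -1.0])  # one token dominates
flat = jnp.array([1.0, 0.9, 0.8, 0.7, 0.6])      # mass is spread almost evenly

peaked_size = nucleus_size(peaked, 0.9)  # the top token alone already exceeds p
flat_size = nucleus_size(flat, 0.9)      # every token is needed to reach p
```

A fixed top-k with k=3 would keep three tokens in both cases, while the nucleus here shrinks to one token for the peaked logits and grows to all five for the flat ones.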

Worked mini-example

import jax, jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0, 4.0])
p = 0.9
MASK_VALUE = -1e9

# Sort descending
sorted_logits = jnp.sort(logits)[::-1]    # [4.0, 3.0, 2.0, 1.0]
sorted_probs = jax.nn.softmax(sorted_logits)
# ≈ [0.644, 0.237, 0.087, 0.032]

cumulative_probs = jnp.cumsum(sorted_probs)
# ≈ [0.644, 0.881, 0.968, 1.0]

# keep[i] is True if the cumulative mass BEFORE including token i is still below p
keep = cumulative_probs - sorted_probs < p
# → [True, True, True, False]  (0.0 < 0.9, 0.644 < 0.9, 0.881 < 0.9, 0.968 >= 0.9)

# Always keep top-1 (already implied by the test above for p > 0; made explicit here)
keep = keep | (jnp.arange(logits.shape[0]) == 0)

# Smallest kept logit = threshold
threshold = jnp.where(keep, sorted_logits, jnp.inf).min()  # 2.0

masked = jnp.where(logits >= threshold, logits, MASK_VALUE)
# → [-1e9, 2.0, 3.0, 4.0]

Common pitfalls

  • Sort DESCENDING: jnp.sort(logits)[::-1] not ascending.
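  • Edge case, top-1 exceeds p: always keep at least the top-1 token, even if its probability alone exceeds p. The exclusive-cumsum test already guarantees this for p > 0, because the mass before the first token is 0; the | (arange == 0) guard makes it explicit.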
  • Ties: the threshold comparison (>=) retains all logits tied at the threshold value in the original (unsorted) array.
  • Cumsum logic: keep token i if cumulative BEFORE including i is < p, i.e. cumsum[i] - prob[i] < p, not cumsum[i] < p.
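  • Real LLM code often uses -inf: we use -1e9 here for JSON serialization in the test contract, since -inf is not valid JSON. For logits of ordinary magnitude, exp(-1e9) underflows to 0 under softmax, so the two give the same probabilities in practice.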
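
Two of the pitfalls above can be checked directly. The sketch below reuses the logits from the worked mini-example to compare the inclusive and exclusive cumsum tests, then compares the softmax of a -1e9 mask with that of a -inf mask. Variable names such as keep_inclusive and mass_exclusive are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0, 4.0])
p = 0.9

sorted_probs = jax.nn.softmax(jnp.sort(logits)[::-1])
cumulative_probs = jnp.cumsum(sorted_probs)

# Inclusive test: drops the token whose mass carries the total past p,
# so the kept mass (about 0.881 here) falls short of p.
keep_inclusive = cumulative_probs < p

# Exclusive test: keeps that crossing token, so the kept mass reaches p.
keep_exclusive = cumulative_probs - sorted_probs < p

mass_inclusive = float(jnp.sum(jnp.where(keep_inclusive, sorted_probs, 0.0)))
mass_exclusive = float(jnp.sum(jnp.where(keep_exclusive, sorted_probs, 0.0)))

# Mask value: with logits of this magnitude, -1e9 and -inf give the same softmax.
in_nucleus = logits >= 2.0  # nucleus from the worked example
probs_finite = jax.nn.softmax(jnp.where(in_nucleus, logits, -1e9))
probs_inf = jax.nn.softmax(jnp.where(in_nucleus, logits, -jnp.inf))
same_probs = bool(jnp.allclose(probs_finite, probs_inf))
```

The inclusive variant keeps two tokens and the exclusive variant keeps three; only the latter satisfies the "at least p" definition of the nucleus.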

Problem

Implement nucleus_top_p_mask(logits, p) that returns the logits with all tokens outside the nucleus replaced by -1e9.

logits is a 1-D float array of shape (K,). p is a float in (0, 1]. Return a 1-D float32 array of the same shape (K,).
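
One possible shape for such a function, assembled from the steps of the worked mini-example, is sketched below. It is a minimal sketch rather than the reference solution: it assumes p > 0 and logits of ordinary magnitude, and it has not been checked against the site's test contract.

```python
import jax
import jax.numpy as jnp

MASK_VALUE = -1e9


def nucleus_top_p_mask(logits, p):
    """Replace logits outside the top-p nucleus with MASK_VALUE (sketch)."""
    logits = jnp.asarray(logits, dtype=jnp.float32)
    sorted_logits = jnp.sort(logits)[::-1]                 # descending order
    sorted_probs = jax.nn.softmax(sorted_logits)
    mass_before = jnp.cumsum(sorted_probs) - sorted_probs  # exclusive cumulative mass
    keep = mass_before < p
    keep = keep | (jnp.arange(logits.shape[0]) == 0)       # explicit top-1 guard
    threshold = jnp.min(jnp.where(keep, sorted_logits, jnp.inf))
    # Compare against the ORIGINAL (unsorted) logits so positions are preserved.
    return jnp.where(logits >= threshold, logits, MASK_VALUE).astype(jnp.float32)


masked = nucleus_top_p_mask(jnp.array([1.0, 2.0, 3.0, 4.0]), 0.9)
# masked keeps 2.0, 3.0 and 4.0 in place; the first entry becomes -1e9
```

Because the threshold is applied to the unsorted array, the function never needs to undo the sort, and any logits tied at the threshold are all retained.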

Hints

jax top-p nucleus
