Top-k Logit Masking
Why this matters
Top-k sampling is a standard technique in LLM decoding that truncates the vocabulary to only the k highest-probability tokens before sampling. This prevents the model from ever sampling very unlikely tokens (which can produce incoherent text) while still allowing diversity among the most plausible continuations. The approach is used in GPT-2, GPT-3, and most production LLM deployments.
The canonical masking approach uses -inf so that masked tokens get exactly
0 probability after softmax. In this problem we use -1e9 instead: a large
finite negative value that is equivalent for masking in practice
(exp(x - 1e9) underflows to 0 in float32 for any reasonable logit x, so the
masked token's softmax probability is 0) while remaining
JSON-serializable for the test contract.
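To make the -inf versus -1e9 trade-off concrete, here is a minimal sketch. The three-element `logits` array and the choice to mask index 2 are illustrative assumptions, not values from the test set. It checks that both mask values give the same float32 softmax output, and that only the finite one passes strict JSON encoding.

```python
import json

import jax
import jax.numpy as jnp

# Illustrative logits (an assumption for this sketch, not from the test set).
logits = jnp.array([5.0, 4.0, 1.0])

# Mask the last entry two ways: with -inf and with the finite -1e9.
masked_inf = logits.at[2].set(-jnp.inf)
masked_finite = logits.at[2].set(-1e9)

probs_inf = jax.nn.softmax(masked_inf)
probs_finite = jax.nn.softmax(masked_finite)

# In float32, exp(-1e9 - max_logit) underflows to 0.0, so both results agree.
same = bool(jnp.array_equal(probs_inf, probs_finite))

# -1e9 survives strict JSON encoding; -inf is rejected when allow_nan=False.
finite_json = json.dumps(masked_finite.tolist(), allow_nan=False)
try:
    json.dumps(masked_inf.tolist(), allow_nan=False)
    inf_is_strict_json = True
except ValueError:
    inf_is_strict_json = False
```

Python's `json.dumps` accepts infinities by default and writes them as `-Infinity`, which is not valid JSON; the sketch passes `allow_nan=False` to surface that.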
Worked mini-example
import jax, jax.numpy as jnp
logits = jnp.array([1.0, 5.0, 3.0, 2.0, 4.0])
k = 2
MASK_VALUE = -1e9
top_k_values, _ = jax.lax.top_k(logits, k) # [5.0, 4.0]
threshold = top_k_values[-1] # 4.0
masked = jnp.where(logits >= threshold, logits, MASK_VALUE)
# -> [-1e9, 5.0, -1e9, -1e9, 4.0]
# After softmax: [~0, 0.731, ~0, ~0, 0.269]
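The softmax figures quoted above can be checked with a short follow-up sketch. It starts from the masked array of the mini-example; the sampling step, the PRNG seed, and the sample count of 1000 are illustrative additions, not part of the problem statement.

```python
import jax
import jax.numpy as jnp

# Masked logits from the worked mini-example (k = 2).
masked = jnp.array([-1e9, 5.0, -1e9, -1e9, 4.0])

# Only indices 1 and 4 carry probability mass:
# e^5 / (e^5 + e^4) and e^4 / (e^5 + e^4).
probs = jax.nn.softmax(masked)

# Sampling from the masked logits can only return a kept index.
key = jax.random.PRNGKey(0)  # arbitrary seed, chosen for this sketch
samples = jax.random.categorical(key, masked, shape=(1000,))
sampled_indices = set(samples.tolist())
```

Every sampled index falls in {1, 4}, which is what truncating the vocabulary before sampling is meant to guarantee.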
Common pitfalls
- Use >= not >: ties at the threshold value must be kept. With >, you would drop tied entries and may keep fewer than k tokens.
- Cast k to int: jax.lax.top_k requires a Python int, not a float.
- Ties can expand the kept set beyond k: if multiple logits equal the threshold, all of them are retained. The test cases reflect this (uniform logits case: all 3 are kept for k=2).
- Do not sort before masking: the output must preserve the original index order of the logits.
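The tie-handling pitfalls can be seen in a small sketch. The uniform array `[2.0, 2.0, 2.0]` is a hypothetical stand-in for the uniform-logits test case, whose actual values are not shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical uniform logits: every entry ties with the threshold.
logits = jnp.array([2.0, 2.0, 2.0])
k = 2.0  # k arrives as a float, as in the problem statement

# Cast k to a Python int before calling jax.lax.top_k.
top_k_values, _ = jax.lax.top_k(logits, int(k))
threshold = top_k_values[-1]  # 2.0

keep_ge = logits >= threshold  # all three entries equal the threshold
keep_gt = logits > threshold   # no entry is strictly greater

n_kept_ge = int(keep_ge.sum())  # more than k entries are kept
n_kept_gt = int(keep_gt.sum())  # fewer than k entries are kept
```

With `>=` all three tied entries survive, matching the behavior described for the uniform case; with `>` none do, so every token would be masked.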
Problem
Implement top_k_mask(logits, k) that returns the logits with all but the
top-k values replaced by -1e9.
logits is a 1-D float array of shape (K,). k is a positive float
(cast to int inside). Return a 1-D float32 array of the same shape (K,).
Note: real LLM code uses -inf for masking, but we use -1e9 here for
JSON serialization in the test contract.
One illustrative example (not from the test set):
- top_k_mask(jnp.array([3.0, 1.0, 2.0]), 1.0) returns [3.0, -1e9, -1e9]: only the top-1 is kept.
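One possible shape for `top_k_mask`, assembled directly from the steps of the worked mini-example, is sketched below. It is not the site's reference solution, and it assumes 1 <= k <= K with no extra input validation.

```python
import jax
import jax.numpy as jnp

MASK_VALUE = -1e9  # finite stand-in for -inf, as specified in the problem


def top_k_mask(logits, k):
    """Sketch: keep entries >= the k-th largest logit, mask the rest."""
    logits = jnp.asarray(logits, dtype=jnp.float32)
    k = int(k)  # jax.lax.top_k needs a Python int
    top_k_values, _ = jax.lax.top_k(logits, k)
    threshold = top_k_values[-1]  # k-th largest value
    # jnp.where preserves the original index order; ties at the
    # threshold are kept because the comparison is >=.
    return jnp.where(logits >= threshold, logits, MASK_VALUE)


out = top_k_mask(jnp.array([3.0, 1.0, 2.0]), 1.0)
```

Applied to the illustrative example, `out` keeps 3.0 at index 0 and replaces the other two entries with -1e9.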