medium primitives

Top-k Logit Masking

Why this matters

Top-k sampling is a standard technique in LLM decoding that restricts the vocabulary to the k highest-probability tokens before sampling. This prevents the model from sampling very unlikely tokens (which can produce incoherent text) while still allowing diversity among the most plausible continuations. The approach was used to generate the published GPT-2 samples and is a common decoding option in LLM inference libraries.

The canonical masking approach uses -inf so that masked tokens get exactly 0 probability after softmax. In this problem we use -1e9 instead: a large finite negative value that behaves the same in practice (a token masked to -1e9 gets a softmax probability of ≈ 0 whenever at least one kept token has a reasonable logit) while remaining JSON-serializable for the test contract.
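To make the "behaves the same in practice" claim concrete, here is a minimal sketch comparing the two mask values under softmax. The keep mask is written out by hand for illustration (it is not produced by a top-k call), and the comparison assumes float32 arithmetic, where exp of a value near -1e9 underflows to 0.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 5.0, 3.0, 2.0, 4.0], dtype=jnp.float32)
# Hand-written mask marking the two largest logits (illustrative assumption).
keep = jnp.array([False, True, False, False, True])

masked_finite = jnp.where(keep, logits, -1e9)
masked_inf = jnp.where(keep, logits, -jnp.inf)

probs_finite = jax.nn.softmax(masked_finite)
probs_inf = jax.nn.softmax(masked_inf)

# Both mask values should give the same distribution over the kept tokens.
same = bool(jnp.allclose(probs_finite, probs_inf))
```

The equivalence relies on at least one kept logit being far larger than -1e9; if every entry were masked, the -inf variant would produce NaN while the -1e9 variant would produce a uniform distribution.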

Worked mini-example

import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 5.0, 3.0, 2.0, 4.0])
k = 2
MASK_VALUE = -1e9

top_k_values, _ = jax.lax.top_k(logits, k)  # [5.0, 4.0]
threshold = top_k_values[-1]                  # 4.0

masked = jnp.where(logits >= threshold, logits, MASK_VALUE)
# → [-1e9, 5.0, -1e9, -1e9, 4.0]
# After softmax: [~0, 0.731, ~0, ~0, 0.269]
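
The masked logits are normally fed straight into a categorical sampler. The following sketch draws from the masked vector of the mini-example using jax.random.categorical; the seed and the number of draws are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Masked logits from the mini-example: only indices 1 and 4 survive.
masked = jnp.array([-1e9, 5.0, -1e9, -1e9, 4.0])

key = jax.random.PRNGKey(0)  # arbitrary seed (assumption)
samples = jax.random.categorical(key, masked, shape=(1000,))

# Collect the distinct token indices that were actually drawn.
drawn = set(int(i) for i in samples)
```

Because categorical takes unnormalized logits, no explicit softmax is needed before sampling; the masked positions are effectively never selected.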

Common pitfalls

  • Use >= not >: ties at the threshold value must be kept. With >, you would drop tied entries and may keep fewer than k tokens.
  • Cast k to int: jax.lax.top_k requires a Python int, not a float.
  • Ties can expand the kept set beyond k: if multiple logits equal the threshold, all of them are retained. The test cases reflect this (uniform logits case: all 3 are kept for k=2).
  • Do not sort before masking: the output must preserve the original index order of the logits.
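
The tie behavior described in the pitfalls above can be checked with a small sketch. The uniform logit values are made up for illustration, and the variable names are hypothetical rather than part of the test contract.

```python
import jax
import jax.numpy as jnp

MASK_VALUE = -1e9
logits = jnp.array([2.0, 2.0, 2.0])  # uniform logits: every entry ties
k = 2.0                              # k arrives as a float in this problem

# int(k): jax.lax.top_k needs an integer k.
top_vals, _ = jax.lax.top_k(logits, int(k))
threshold = top_vals[-1]

kept_ge = jnp.where(logits >= threshold, logits, MASK_VALUE)
kept_gt = jnp.where(logits > threshold, logits, MASK_VALUE)

# Count the entries that survived each comparison.
n_kept_ge = int(jnp.sum(kept_ge > MASK_VALUE))  # 3: all tied entries kept
n_kept_gt = int(jnp.sum(kept_gt > MASK_VALUE))  # 0: every entry masked
```

With >=, the kept count exceeds k because of the ties; with >, nothing survives at all, which would leave the sampler with no usable tokens.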

Problem

Implement top_k_mask(logits, k) that returns the logits with all but the top-k values replaced by -1e9.

logits is a 1-D float array of shape (K,), where K is the vocabulary size (not to be confused with k). k is a positive float holding a whole number (cast it to int inside the function). Return a 1-D float32 array of the same shape (K,).

Note: real LLM code uses -inf for masking, but we use -1e9 here for JSON serialization in the test contract.
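
A short sketch of the serialization issue, using Python's standard json module. How the test contract actually serializes values is not specified here, so this only illustrates the general problem with infinities in JSON.

```python
import json
import math

# A finite mask value round-trips as an ordinary JSON number.
finite_payload = json.dumps([3.0, -1e9, -1e9])

# By default Python emits the non-standard token -Infinity.
inf_payload = json.dumps([3.0, -math.inf])

# With strict compliance requested, serialization fails instead.
try:
    json.dumps([3.0, -math.inf], allow_nan=False)
    strict_ok = True
except ValueError:
    strict_ok = False
```

The -Infinity token is not part of the JSON grammar, so strict parsers in other languages may reject a payload that contains it.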

One illustrative example (not from the test set):

  • top_k_mask(jnp.array([3.0, 1.0, 2.0]), 1.0) returns [3.0, -1e9, -1e9]: only the top-1 value is kept.

Hints

jax top-k sampling
