Perplexity from Logits

Compute language-model perplexity from raw logits.

What is perplexity?

Perplexity is the standard metric for evaluating language models. It measures how “surprised” the model is by the actual tokens — lower is better.

Formula

Given a sequence of token predictions, compute the mean negative log-likelihood (NLL) across all positions, then exponentiate:

NLL  = -mean_over_all_positions( log P(target_token | context) )
PPL  = exp(NLL)
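
As a small worked instance of the formula, the sketch below assumes four made-up probabilities that a model assigned to its target tokens (illustrative values, not taken from any real model) and uses only Python's math module.

```python
import math

# Hypothetical probabilities the model gave to the correct token at 4 positions.
target_probs = [0.5, 0.25, 0.125, 0.125]

# NLL = -mean(log P(target_token | context)), in nats.
nll = -sum(math.log(p) for p in target_probs) / len(target_probs)

# PPL = exp(NLL)
ppl = math.exp(nll)

print(f"NLL = {nll:.4f} nats, PPL = {ppl:.4f}")
```

In base 2 the same targets cost 1, 2, 3, and 3 bits, so the mean is 2.25 bits and the perplexity is 2^2.25 ≈ 4.76.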

Geometric interpretation

Perplexity can be read as the effective branching factor at each step: it is the reciprocal of the geometric mean of the probabilities assigned to the target tokens. A perplexity of 4 means the model is, on average, as uncertain as if it were choosing uniformly among 4 options at every position.

  • PPL = 1 → perfect predictions (model assigns all probability to the correct token).
  • PPL = V → uniform distribution over a vocabulary of size V.
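
The two boundary cases above can be checked numerically. This is a minimal pure-Python sketch; the helper names log_softmax and perplexity are hypothetical, and the logit values are chosen only to produce a uniform and a near-certain prediction.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def perplexity(logit_rows, targets):
    """exp of the mean negative log-probability of each target token."""
    nll = -sum(log_softmax(r)[t] for r, t in zip(logit_rows, targets)) / len(targets)
    return math.exp(nll)

V = 4

# Equal logits give a uniform distribution over V tokens, so PPL should be V.
uniform_ppl = perplexity([[0.0] * V] * 3, [0, 1, 2])

# A very large logit on the correct token gives PPL very close to 1.
confident_ppl = perplexity([[50.0, 0.0, 0.0, 0.0]] * 3, [0, 0, 0])

print(uniform_ppl, confident_ppl)
```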

Implementation steps

  1. Compute log_softmax over the vocab dimension (dim -1) of logits.
  2. Gather the log-probability of each target token at every position.
    • PyTorch: log_probs.gather(-1, targets.long().unsqueeze(-1)).squeeze(-1)
    • JAX: flatten targets from (N, T) → (N*T,) and log_probs from (N, T, V) → (N*T, V), then index with arange(N*T).
  3. Compute NLL = -mean(target_log_probs).
  4. Return exp(NLL) as a Python float.
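
The four steps above can be sketched as follows. This is not the reference solution: it uses NumPy as a stand-in for PyTorch or JAX, gathers with the flatten-and-arange pattern from step 2, and the function name perplexity_from_logits is an assumption made for illustration.

```python
import numpy as np

def perplexity_from_logits(logits, targets):
    """Sketch: logits has shape (N, T, V); targets has shape (N, T), possibly as floats."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets).astype(np.int64)  # ids may arrive as floats
    n, t, v = logits.shape

    # Step 1: numerically stable log-softmax over the vocab dimension.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Step 2: flatten to (N*T, V) and (N*T,), then pick each target's log-prob.
    flat_log_probs = log_probs.reshape(n * t, v)
    flat_targets = targets.reshape(n * t)
    target_log_probs = flat_log_probs[np.arange(n * t), flat_targets]

    # Step 3: mean negative log-likelihood over all N*T positions.
    nll = -target_log_probs.mean()

    # Step 4: exponentiate and return a plain Python float.
    return float(np.exp(nll))

# Usage with made-up data: N = 2 sequences, T = 3 positions, V = 5 tokens.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))
targets = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 0.0]])
ppl = perplexity_from_logits(logits, targets)
print(ppl)
```

Subtracting the per-position maximum before exponentiating leaves the log-softmax unchanged and avoids overflow for large logits.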

When to use perplexity

  • Comparing language models on held-out text (lower = better).
  • Evaluating after fine-tuning or continued pre-training.
  • Any pipeline where log-likelihood of sequences is the signal.

Inputs

  • logits: shape (N, T, V) — raw (pre-softmax) logits over a vocabulary of size V, at each of T positions, for N sequences.
  • targets: shape (N, T) — integer target token ids (delivered as floats).

Output

Scalar float perplexity = exp(mean cross-entropy across all N×T positions).
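
Because the mean is taken over all N×T positions before exponentiating, the result generally differs from an average of per-sequence perplexities. The sketch below illustrates this with hypothetical target-token probabilities for two short sequences; the numbers are invented for the example.

```python
import math

# Hypothetical target-token probabilities for N = 2 sequences of T = 2 positions.
target_probs = [
    [0.5, 0.5],      # sequence A
    [0.125, 0.125],  # sequence B
]

# Pooled: one mean over all N*T positions, then a single exp.
all_nll = [-math.log(p) for seq in target_probs for p in seq]
pooled_ppl = math.exp(sum(all_nll) / len(all_nll))

# Not equivalent: exponentiate per sequence, then average the perplexities.
per_seq_ppl = [
    math.exp(-sum(math.log(p) for p in seq) / len(seq)) for seq in target_probs
]
mean_of_ppls = sum(per_seq_ppl) / len(per_seq_ppl)

print(pooled_ppl, mean_of_ppls)
```

Here the pooled perplexity is 4, while the per-sequence perplexities are 2 and 8, whose average is 5.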
