Perplexity from Logits
Compute language-model perplexity from raw logits.
What is perplexity?
Perplexity is a standard metric for evaluating language models. It measures how "surprised" the model is by the actual tokens; lower is better.
Formula
Given a sequence of token predictions, compute the mean negative log-likelihood (NLL) across all positions, then exponentiate:
NLL = -mean_over_all_positions( log P(target_token | context) )
PPL = exp(NLL)
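Here is the same formula with explicit indices. The symbols z (logits) and y (target ids) are notation introduced here, while N, T and V follow the Inputs section. The last equality rewrites perplexity as the geometric mean of the inverse target probabilities. The branching-factor reading rests on that form.

```latex
\log p_{n,t} = z_{n,t,y_{n,t}} - \log \sum_{v=1}^{V} \exp\left(z_{n,t,v}\right)

\mathrm{NLL} = -\frac{1}{NT} \sum_{n=1}^{N} \sum_{t=1}^{T} \log p_{n,t}

\mathrm{PPL} = \exp(\mathrm{NLL}) = \left( \prod_{n=1}^{N} \prod_{t=1}^{T} \frac{1}{p_{n,t}} \right)^{1/(NT)}
```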
Geometric interpretation
Perplexity can be read as the effective (geometric-mean) branching factor the model faces at each step. A perplexity of 4 means the model is, on average, as uncertain as if it were choosing uniformly among 4 options at every position.
- PPL = 1 → perfect predictions (the model assigns all probability to the correct token).
- PPL = V → uniform distribution over a vocabulary of size V.
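This is a minimal sketch of the formula applied to target-token probabilities. It assumes those probabilities have already been extracted into a flat Python list, and the helper name `perplexity_from_probs` is hypothetical. The sketch checks the two endpoints listed above plus one mixed case. In that case, probabilities 1/2 and 1/4 give a perplexity of sqrt(2 × 4) = 2√2.

```python
import math

def perplexity_from_probs(target_probs):
    """Perplexity from the probabilities assigned to the correct tokens.

    target_probs: flat list of P(target_token | context), one per position
    (hypothetical input format chosen for illustration).
    """
    # Mean negative log-likelihood over all positions.
    nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(nll)

# Mixed case: geometric mean of 1/p is sqrt(2 * 4), about 2.828.
print(perplexity_from_probs([0.5, 0.25]))
# Uniform over V = 4 at every position: PPL is V (up to float rounding).
print(perplexity_from_probs([0.25] * 10))
# All probability on the correct token: PPL is 1.
print(perplexity_from_probs([1.0, 1.0]))
```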
Implementation steps
- Compute log_softmax over the vocab dimension (the last axis, dim=-1) of logits.
- Gather the log-probability of each target token at every position:
  - PyTorch: log_probs.gather(-1, targets.long().unsqueeze(-1)).squeeze(-1)
  - JAX: flatten (N, T) → (N*T,) (and log_probs to (N*T, V)), cast the targets to an integer dtype, and index with jnp.arange(N*T) paired with the flattened target ids.
- Compute NLL = -mean(target_log_probs).
- Return exp(NLL) as a Python float.
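The steps above can be sketched in plain Python. This sketch uses only the standard library, with nested lists standing in for tensors, so it does not reproduce the PyTorch or JAX calls themselves. The names `log_softmax` and `perplexity` are hypothetical helpers chosen here. The log-softmax subtracts the row maximum before exponentiating, which is a standard trick for numerical stability.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    m = max(row)  # subtract the max so exp() cannot overflow
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def perplexity(logits, targets):
    """Perplexity from raw logits.

    logits:  nested lists of shape (N, T, V)
    targets: nested lists of shape (N, T); token ids may be stored as floats
    """
    total_log_prob = 0.0
    count = 0
    for seq_logits, seq_targets in zip(logits, targets):
        for position_logits, target in zip(seq_logits, seq_targets):
            log_probs = log_softmax(position_logits)   # step 1: log_softmax
            total_log_prob += log_probs[int(target)]    # step 2: gather target
            count += 1
    nll = -total_log_prob / count                       # step 3: mean NLL
    return float(math.exp(nll))                         # step 4: exp as float

# One sequence, two positions, vocabulary of size 4; targets given as floats.
logits = [[[0.0, 0.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]]]
targets = [[1.0, 0.0]]
print(perplexity(logits, targets))
```

In this toy input, the first position is uniform (p = 1/4) and the second puts probability e²/(e² + 3) on the target. The perplexity therefore falls between 1 and 4.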
When to use perplexity
- Comparing language models on held-out text (lower = better).
- Evaluating after fine-tuning or continued pre-training.
- Any pipeline where log-likelihood of sequences is the signal.
Inputs
- logits: shape (N, T, V). Raw (pre-softmax) logits over a vocabulary of size V, at each of T positions, for N sequences.
- targets: shape (N, T). Integer target token ids (delivered as floats).
Output
Scalar float perplexity = exp(mean cross-entropy across all N×T positions).
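The cross-entropy of a one-hot target y against logits z is logsumexp(z) - z[y]. Because of this, the same output can be computed without materializing the full log-softmax. Below is a minimal plain-Python sketch under that formulation. The helper names `logsumexp` and `perplexity_via_cross_entropy` are hypothetical.

```python
import math

def logsumexp(row):
    # Stable log(sum(exp(x))) via max subtraction.
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def perplexity_via_cross_entropy(logits, targets):
    """exp(mean cross-entropy) over all N*T positions.

    Per position, cross-entropy against target id y equals
    logsumexp(z) - z[y], which is -log_softmax(z)[y].
    """
    losses = [
        logsumexp(z) - z[int(y)]
        for seq_z, seq_y in zip(logits, targets)
        for z, y in zip(seq_z, seq_y)
    ]
    return math.exp(sum(losses) / len(losses))

# Large logits stay finite thanks to the max subtraction.
print(perplexity_via_cross_entropy([[[1000.0, 1000.0, 1000.0]]], [[2.0]]))
```

With PyTorch, the same number should come out of torch.exp(torch.nn.functional.cross_entropy(logits.reshape(-1, V), targets.reshape(-1).long())), because cross_entropy averages over all N*T rows by default.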