easy primitives

Temperature Scaling

Why this matters

Temperature scaling is the most fundamental knob in LLM decoding. Most LLM APIs and sampling libraries expose a temperature parameter that controls how peaked or spread out the next-token distribution is. A temperature of 1.0 leaves the model’s raw logits unchanged. Values below 1.0 sharpen the distribution, so the highest-probability tokens get even more of the mass, while values above 1.0 flatten it, making rare tokens more likely. This single scalar largely determines whether a model sounds confident and repetitive or creative and diverse.

Understanding the mechanics (dividing logits BEFORE softmax) prevents a common mistake: applying temperature after softmax, which yields a vector that no longer sums to 1 rather than a reshaped probability distribution.
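
To make the order of operations concrete, here is a minimal sketch that writes the scaled softmax out by hand, p_i = exp(z_i / T) / sum_j exp(z_j / T), and compares it with jax.nn.softmax. The helper name manual_temperature_softmax is hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def manual_temperature_softmax(logits, temperature):
    # Hypothetical helper: scale the logits first, then normalize.
    scaled = logits / temperature
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    shifted = scaled - jnp.max(scaled)
    weights = jnp.exp(shifted)
    return weights / jnp.sum(weights)


logits = jnp.array([1.0, 2.0, 3.0])
manual = manual_temperature_softmax(logits, 0.5)
builtin = jax.nn.softmax(logits / 0.5)
```

Under these assumptions, manual and builtin should agree up to floating-point rounding, since both divide by the temperature before normalizing.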

Worked mini-example

import jax, jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])

# T=1 (unchanged)
jax.nn.softmax(logits / 1.0)   # [0.090, 0.245, 0.665]

# T=0.5 (sharper: mass concentrates on index 2)
jax.nn.softmax(logits / 0.5)   # [0.016, 0.117, 0.867]

# T=2.0 (flatter: more uniform)
jax.nn.softmax(logits / 2.0)   # [0.186, 0.307, 0.506]

Common pitfalls

  • Divide BEFORE softmax: softmax(logits / T) is correct. Dividing after (softmax(logits) / T) just multiplies a valid distribution by 1/T, so for any T other than 1 the output no longer sums to 1 and is not a probability vector.
  • T=0 is undefined (division by zero). With JAX arrays the division does not raise an error; it silently produces inf or nan values. For greedy decoding, use jnp.argmax(logits) instead.
  • Temperature does not change argmax: the index of the maximum is invariant to positive scaling of logits.
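
The three pitfalls above can be checked numerically. The following is a small sketch using the same logits as the worked example; the variable names are chosen for illustration only.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
T = 0.5

correct = jax.nn.softmax(logits / T)  # scale, then normalize
wrong = jax.nn.softmax(logits) / T    # normalize, then scale (the pitfall)

correct_sum = float(jnp.sum(correct))  # approximately 1.0
wrong_sum = float(jnp.sum(wrong))      # approximately 1 / T = 2.0

# Greedy decoding: pick the largest logit directly instead of dividing by zero.
greedy_index = int(jnp.argmax(logits))

# The argmax of the scaled logits is the same for every positive temperature.
same_argmax = all(
    int(jnp.argmax(logits / t)) == greedy_index
    for t in (0.1, 0.5, 1.0, 2.0, 10.0)
)
```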

Problem

Implement temperature_softmax(logits, temperature) that returns the softmax of the temperature-scaled logits.

logits is a 1-D float array of shape (K,). temperature is a positive float scalar. Return a 1-D float32 array of shape (K,) summing to 1.

One illustrative example (not from the test set):

  • temperature_softmax(jnp.array([0.0, 0.0]), 0.5) returns [0.5, 0.5]: equal logits give a uniform distribution at any positive temperature.
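
One possible sketch of the function, offered as an illustration and not as the reference solution. It assumes the input can be cast to float32 and that temperature is strictly positive, as the problem statement says; no validation is performed.

```python
import jax
import jax.numpy as jnp


def temperature_softmax(logits, temperature):
    # Cast to float32 so the output dtype matches the problem statement.
    logits = jnp.asarray(logits, dtype=jnp.float32)
    # Divide BEFORE softmax; temperature is assumed to be > 0.
    return jax.nn.softmax(logits / temperature)


example = temperature_softmax(jnp.array([0.0, 0.0]), 0.5)  # [0.5, 0.5]
```

Relying on jax.nn.softmax rather than a hand-written exp and sum keeps the computation numerically stable for large logits or small temperatures.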

Hints

jax temperature sampling
