Temperature Scaling
Why this matters
Temperature scaling is one of the most fundamental knobs in LLM decoding.
Most language-model sampling APIs expose a temperature parameter that
controls how peaked or spread out the next-token distribution is. A
temperature of 1.0 leaves the model's raw logits unchanged. Values below
1.0 sharpen the distribution, so the highest-probability tokens get even
more of the mass. Values above 1.0 flatten it, making rare tokens more
likely. This single scalar largely determines whether a model's output
sounds confident and repetitive or creative and diverse.
It helps to understand the mechanics: the logits are divided by T BEFORE the softmax. A common mistake is to apply temperature after the softmax instead, and that does not yield a valid probability distribution.
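In symbols, write z_1, ..., z_K for the logits (K as in the problem statement below) and T > 0 for the temperature. The first line is the temperature-scaled softmax. The second line shows why dividing after the softmax fails. The last line gives the two limiting cases, where the T -> 0+ case assumes a unique maximum logit.

```latex
p_i(T) = \frac{\exp(z_i / T)}{\sum_{j=1}^{K} \exp(z_j / T)}, \qquad i = 1, \dots, K

\sum_{i=1}^{K} \frac{\operatorname{softmax}(z)_i}{T} = \frac{1}{T} \neq 1 \quad \text{whenever } T \neq 1

\lim_{T \to \infty} p_i(T) = \frac{1}{K}, \qquad
\lim_{T \to 0^{+}} p_i(T) = \mathbf{1}\left[\, i = \arg\max_j z_j \,\right]
```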
Worked mini-example
import jax, jax.numpy as jnp
logits = jnp.array([1.0, 2.0, 3.0])
# T=1 (unchanged)
jax.nn.softmax(logits / 1.0) # [0.090, 0.245, 0.665]
# T=0.5 (sharper: mass concentrates on index 2)
jax.nn.softmax(logits / 0.5) # [0.016, 0.117, 0.867]
# T=2.0 (flatter: closer to uniform)
jax.nn.softmax(logits / 2.0) # [0.186, 0.307, 0.506]
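The worked example above only computes the distributions. In decoding, a token is then drawn from one of them. The sketch below samples with jax.random.categorical. That function treats its input as unnormalized log-probabilities, so the temperature-scaled logits can be passed in directly. The name sample_with_temperature is a hypothetical helper for illustration and does not come from any library.

```python
import jax
import jax.numpy as jnp


def sample_with_temperature(key, logits, temperature, num_samples):
    """Draw token indices from softmax(logits / temperature).

    Hypothetical helper for illustration; assumes temperature > 0.
    """
    # jax.random.categorical interprets its input as unnormalized
    # log-probabilities, so no explicit softmax is needed here.
    return jax.random.categorical(key, logits / temperature, shape=(num_samples,))


logits = jnp.array([1.0, 2.0, 3.0])
keys = jax.random.split(jax.random.PRNGKey(0), 3)

for key, T in zip(keys, (0.5, 1.0, 2.0)):
    samples = sample_with_temperature(key, logits, T, 10_000)
    # Empirical frequency of each token index; should approach softmax(logits / T).
    freqs = jnp.bincount(samples, length=logits.shape[0]) / samples.shape[0]
    print(f"T={T}: {freqs}")
```

With enough samples, the printed frequencies should land close to the T=0.5, 1.0 and 2.0 distributions in the worked example. The exact counts depend on the PRNG key.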
Common pitfalls
- Divide BEFORE softmax: softmax(logits / T) is correct. Dividing after, as in softmax(logits) / T, multiplies every entry of a valid distribution by 1/T. The output then sums to 1/T rather than 1 (unless T = 1), so it is no longer a probability vector.
- T=0 is undefined (division by zero). Greedy decoding is the limit as T approaches 0 from above. To get it, use jnp.argmax(logits) instead.
- Temperature does not change the argmax: the index of the maximum logit is unchanged when all logits are divided by the same positive T.
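The pitfalls above can be checked numerically. The sketch below uses the same logits as the worked example. The value T = 2.0 is an arbitrary choice, and 1e-3 is just one example of a "very small" temperature.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
T = 2.0

# Correct: scale the logits first, then normalize.
correct = jax.nn.softmax(logits / T)

# Incorrect: normalize first, then divide. Every entry is multiplied by 1/T,
# so the vector sums to 1/T (here 0.5) instead of 1.
wrong = jax.nn.softmax(logits) / T

print(correct.sum())  # approximately 1.0
print(wrong.sum())    # approximately 0.5

# The argmax is the same for any positive temperature.
for t in (0.1, 0.5, 1.0, 2.0, 10.0):
    assert jnp.argmax(logits / t) == jnp.argmax(logits)

# Greedy decoding picks the argmax directly, so T = 0 is never divided by.
greedy_token = jnp.argmax(logits)
print(greedy_token)  # 2

# A very small T approaches a one-hot vector on the argmax. jax.nn.softmax
# subtracts the max logit internally, so the large scaled logits do not overflow.
print(jax.nn.softmax(logits / 1e-3))
```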
Problem
Implement temperature_softmax(logits, temperature) that returns the
softmax of the temperature-scaled logits.
logits is a 1-D float array of shape (K,). temperature is a positive
float scalar. Return a 1-D float32 array of shape (K,) summing to 1.
One illustrative example (not from the test set):
- temperature_softmax(jnp.array([0.0, 0.0]), 0.5) returns [0.5, 0.5]: uniform logits stay uniform at any temperature.
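Here is one possible sketch that matches the stated signature and output contract. It is not the site's reference solution. It assumes temperature arrives as a concrete Python float, because the positivity check uses a Python `if`. Under jax.jit a traced temperature would need a different guard.

```python
import jax
import jax.numpy as jnp


def temperature_softmax(logits, temperature):
    """Softmax of temperature-scaled logits.

    Minimal sketch: assumes `logits` has shape (K,) and `temperature` is a
    concrete positive scalar (not a traced value inside jax.jit).
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive; use jnp.argmax for greedy decoding")
    # Cast to float32 so the output dtype matches the spec even for integer inputs.
    scaled = jnp.asarray(logits, dtype=jnp.float32) / temperature
    # jax.nn.softmax subtracts the max internally, so small temperatures
    # (large scaled logits) do not overflow.
    return jax.nn.softmax(scaled)


print(temperature_softmax(jnp.array([0.0, 0.0]), 0.5))
print(temperature_softmax(jnp.array([1.0, 2.0, 3.0]), 0.5))
```

Raising an error for T <= 0 is a design choice. Another option is to fall back to a one-hot argmax vector, but the problem only specifies positive temperatures.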