medium
primitives
Stable log_softmax
Why this matters
Cross-entropy loss, the backbone of classifier training, requires
log(softmax(logits)). There are two ways to compute this, and one is
strictly better:
| Approach | Code | Notes |
|---|---|---|
| Two-step | `jnp.log(jax.nn.softmax(x))` | Slower and less precise; returns -inf when a probability underflows to 0 |
| Fused | `jax.nn.log_softmax(x)` | Faster, more stable, preferred |
The key identity is:
log_softmax(x)_i = x_i - logsumexp(x)
This is numerically stable because logsumexp(x) subtracts max(x) before
exponentiating, so exp never overflows, and we never take the log of a
softmax(x) value that may have underflowed to 0.
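To make the stability claim concrete, here is a minimal sketch comparing the two approaches on a deliberately extreme input. The logit values are made up for illustration; they are chosen so that one softmax probability underflows to exactly 0.

```python
import jax
import jax.numpy as jnp

# Illustrative logits (not from the lesson): x[0] sits 1000 below the maximum,
# so its softmax probability, exp(-1000), underflows to 0.
x = jnp.array([0.0, 1000.0])

# Two-step: softmax returns exactly 0 for x[0], so log gives -inf.
two_step = jnp.log(jax.nn.softmax(x))

# Fused: x - logsumexp(x) keeps the finite value -1000 for x[0].
fused = jax.nn.log_softmax(x)

print(two_step)
print(fused)
```

A -inf log-probability turns into an infinite loss if that class happens to be the label, which is why the fused form is the safer default.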
Worked mini-example
import jax
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
lse = jax.nn.logsumexp(x)  # ≈ 3.4076
log_probs = jax.nn.log_softmax(x)
# = x - lse = [1 - 3.4076, 2 - 3.4076, 3 - 3.4076]
# ≈ [-2.4076, -1.4076, -0.4076]
# sum(exp(log_probs)) ≈ 1.0 ✓
For uniform inputs like [0, 0, 0, 0], log_softmax returns log(1/4)
= -log(4) ≈ -1.3863 for every element.
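The uniform case is easy to check directly. The sketch below also adds a constant to every logit, which is an extra illustration not in the text above: since log_softmax(x)_i = x_i - logsumexp(x), shifting all logits by the same amount should leave the result unchanged.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(4)                       # uniform logits [0, 0, 0, 0]
log_probs = jax.nn.log_softmax(x)
expected = -jnp.log(4.0)               # log(1/4) = -log(4)

# Shifting every logit by the same constant (7.0 is arbitrary) cancels out
# in x_i - logsumexp(x), so the log-probabilities should match.
shifted_log_probs = jax.nn.log_softmax(x + 7.0)

print(log_probs, expected, shifted_log_probs)
```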
Common pitfalls
- Chaining `log(softmax(x))`: it works for moderate logits but is slower and can return -inf when a probability underflows to 0; use `log_softmax` directly in production code.
- Forgetting that the output is log-probabilities: don't exponentiate unless you specifically want the softmax probabilities.
- Confusing `log_softmax` with `logsumexp`: `logsumexp` returns a scalar; `log_softmax` returns a vector the same shape as the input.
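The last two pitfalls can be seen in a few lines. This is a small sketch reusing the logits from the worked example; the `label` index is a hypothetical target class added only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

lse = jax.nn.logsumexp(x)           # reduces over the array: shape ()
log_probs = jax.nn.log_softmax(x)   # elementwise result: shape (3,)
print(lse.shape, log_probs.shape)

# Exponentiate only when the probabilities themselves are needed.
probs = jnp.exp(log_probs)

# For a cross-entropy style loss, index the log-probabilities directly.
label = 2                           # hypothetical target class
nll = -log_probs[label]
print(probs, nll)
```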
Problem
Implement stable_log_softmax(x) that computes log-probabilities stably.
- x: 1-D jax array (logits).
- Returns: 1-D array of the same shape containing log-probabilities (all ≤ 0; their exponentials sum to 1, so logsumexp of the output is log(1) = 0).
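One possible approach, sketched under the assumption that the identity log_softmax(x)_i = x_i - logsumexp(x) with explicit max-subtraction is what the exercise is after. This is not the site's reference solution, and it assumes a non-empty 1-D input with at least one finite logit.

```python
import jax
import jax.numpy as jnp

def stable_log_softmax(x):
    """Sketch: log-probabilities for 1-D logits via x - logsumexp(x)."""
    x_max = jnp.max(x)                 # largest logit
    shifted = x - x_max                # all entries <= 0, so exp cannot overflow
    lse_shifted = jnp.log(jnp.sum(jnp.exp(shifted)))  # logsumexp(x) - x_max
    return shifted - lse_shifted       # equals x - logsumexp(x)

x = jnp.array([1.0, 2.0, 3.0])
out = stable_log_softmax(x)
print(out)
```

Subtracting the maximum before exponentiating guarantees that the sum inside the log is at least 1 (the maximum contributes exp(0)), so the log never sees 0.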
Hints
jax
softmax
stable