Numerically Stable logsumexp
Why this matters
log(sum(exp(x))), the log-sum-exp operation, appears everywhere in
probabilistic ML: softmax denominators, log-likelihood computations, MCMC
importance weights, and more.
The naive form jnp.log(jnp.sum(jnp.exp(x))) has a critical flaw:
- For large x (e.g., x = [100, 100, 100]): exp(100) overflows to inf in float32 (JAX's default precision), so the sum is inf and the log is inf, which is wrong.
- For very negative x (e.g., x = [-1000, -1000]): exp(-1000) underflows to 0.0, so the sum is 0 and log(0) is -inf, which is wrong.
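Both failure modes can be reproduced in a few lines. The sketch below assumes float32 inputs and uses a helper name, naive_logsumexp, that is introduced here purely for illustration. The mathematically correct answers would be about 101.0986 (100 + log(3)) and about -999.3069 (-1000 + log(2)).

```python
import jax.numpy as jnp

def naive_logsumexp(x):
    # Direct transcription of log(sum(exp(x))), with no stabilisation.
    return jnp.log(jnp.sum(jnp.exp(x)))

large = jnp.array([100.0, 100.0, 100.0], dtype=jnp.float32)
small = jnp.array([-1000.0, -1000.0], dtype=jnp.float32)

overflowed = naive_logsumexp(large)   # exp(100) exceeds the float32 maximum
underflowed = naive_logsumexp(small)  # exp(-1000) rounds to 0.0

print(overflowed)   # inf
print(underflowed)  # -inf
```

Neither call raises an error or warning, which is why the problem is easy to miss in a larger computation.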
The stable trick is: subtract the max before exponentiating.
logsumexp(x) = max(x) + log(sum(exp(x - max(x))))
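Since x - max(x) is always ≤ 0, every exp(x - max(x)) lies in [0, 1], so nothing overflows. The largest element maps to exp(0) = 1, so the sum is at least 1 and the log never sees 0.
The max(x) term is added back algebraically, so the identity is exact; only ordinary floating-point rounding remains.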
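As a minimal sketch of the identity, assuming a non-empty 1-D array of finite values, the trick can be written directly. The function name logsumexp_max_trick is hypothetical, and this is not necessarily how jax.nn.logsumexp is implemented internally.

```python
import jax.numpy as jnp

def logsumexp_max_trick(x):
    # Assumes x is a non-empty 1-D array with a finite maximum.
    m = jnp.max(x)
    # Shifted exponents are <= 0, so exp cannot overflow;
    # the largest element contributes exp(0) = 1, so the sum is >= 1.
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

big = logsumexp_max_trick(jnp.array([100.0, 100.0, 100.0]))
tiny = logsumexp_max_trick(jnp.array([-1000.0, -1000.0]))

print(big)   # close to 100 + log(3)
print(tiny)  # close to -1000 + log(2)
```

If the maximum itself is infinite (for example, an array of all -inf), x - m produces nan, so a production version would need extra handling for that case.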
Worked mini-example
import jax
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
# Naive (works here, overflows for large x)
naive = jnp.log(jnp.sum(jnp.exp(x))) # ≈ 3.4076
# Stable (always works)
stable = jax.nn.logsumexp(x) # ≈ 3.4076
For x = [100, 100, 100], naive gives inf; jax.nn.logsumexp gives
101.099 (= 100 + log(3)).
Common pitfalls
- Writing log(sum(exp(x))) directly: works for small values, silently overflows for large x (e.g., logits can be in the hundreds).
- Using jnp.max(x) then forgetting to add it back: if you try to implement the trick manually, it's easy to drop the + max at the end.
- Confusing logsumexp with log(softmax): they're related but different; log_softmax(x) = x - logsumexp(x).
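The relationship in the last pitfall can be checked numerically. This is a small illustration using jax.nn.log_softmax and jax.nn.logsumexp; the variable names are chosen here for the example only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

lse = jax.nn.logsumexp(x)          # scalar: one number for the whole array
log_probs = jax.nn.log_softmax(x)  # vector: one log-probability per element
manual = x - lse                   # log_softmax(x) = x - logsumexp(x)

print(lse.shape, log_probs.shape)
print(jnp.allclose(log_probs, manual, atol=1e-5))
```

The shapes make the distinction concrete: logsumexp reduces the array to a scalar, while log_softmax keeps the input shape.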
Problem
Implement stable_logsumexp(x) that computes log(sum(exp(x))) numerically stably.
- x: 1-D jax array.
- Returns: scalar.
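One possible way to check an attempt is to compare it against jax.nn.logsumexp on the inputs discussed above. The helper check_logsumexp and its tolerance are assumptions made for this sketch, not part of the problem statement. jax.nn.logsumexp stands in below for your own stable_logsumexp.

```python
import jax
import jax.numpy as jnp

def check_logsumexp(fn, atol=1e-3):
    # Hypothetical helper: compares fn with jax.nn.logsumexp on a few 1-D inputs.
    cases = [
        jnp.array([1.0, 2.0, 3.0]),
        jnp.array([100.0, 100.0, 100.0]),
        jnp.array([-1000.0, -1000.0]),
    ]
    return all(
        bool(jnp.allclose(fn(x), jax.nn.logsumexp(x), atol=atol)) for x in cases
    )

# Stand-in for your own implementation; the reference trivially passes.
reference_ok = check_logsumexp(jax.nn.logsumexp)

# The naive form passes the small case but fails on the extreme ones.
naive_ok = check_logsumexp(lambda x: jnp.log(jnp.sum(jnp.exp(x))))

print(reference_ok, naive_ok)
```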