medium primitives

Numerically Stable logsumexp

Why this matters

log(sum(exp(x))), the log-sum-exp operation, appears throughout probabilistic ML: softmax denominators, log-likelihood computations, MCMC importance weights, and more.

The naive form jnp.log(jnp.sum(jnp.exp(x))) fails in two opposite directions:

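  • For large x (e.g., x = [100, 100, 100]): in float32, JAX's default dtype, exp overflows to inf for arguments above about 88.7, so exp(100) is inf, the sum is inf, and the log is inf. The correct answer is 100 + log(3) ≈ 101.099.
  • For very negative x (e.g., x = [-1000, -1000]): exp(-1000) underflows to 0.0 (even in float64), so the sum is 0 and log(0) is -inf. The correct answer is -1000 + log(2) ≈ -999.307.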
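A minimal sketch reproducing both failure modes, assuming JAX's default float32 dtype (64-bit mode not enabled). naive_logsumexp is a hypothetical name for the direct formula, not a JAX function:

```python
import jax.numpy as jnp


def naive_logsumexp(x):
    # Direct translation of log(sum(exp(x))), with no protection against overflow or underflow.
    return jnp.log(jnp.sum(jnp.exp(x)))


big = jnp.array([100.0, 100.0, 100.0])  # float32 by default in JAX
tiny = jnp.array([-1000.0, -1000.0])

print(naive_logsumexp(big))   # inf   (exp(100) exceeds the float32 max of about 3.4e38)
print(naive_logsumexp(tiny))  # -inf  (exp(-1000) underflows to 0.0, and log(0) = -inf)
```

Neither result raises an error, so the bad values propagate silently into whatever uses them next.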

The standard fix is to subtract the max before exponentiating:

logsumexp(x) = max(x) + log(sum(exp(x - max(x))))

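Since x - max(x) ≤ 0 elementwise, every exp(x - max(x)) lies in [0, 1], so nothing overflows. The largest entry contributes exp(0) = 1 exactly, so the sum is at least 1 and its log is at least 0: the result can no longer collapse to log(0) = -inf, even when the other terms underflow. The max(x) term is added back, so the result is algebraically identical to log(sum(exp(x))).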
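Factoring e^m out of the sum, with m = max(x), shows why the shift changes nothing mathematically and also gives a quick sanity bound on the result for an n-element x:

```latex
\begin{aligned}
\log \sum_{i=1}^{n} e^{x_i}
  &= \log\Big( e^{m} \sum_{i=1}^{n} e^{x_i - m} \Big)
   = m + \log \sum_{i=1}^{n} e^{x_i - m}, \qquad m = \max_i x_i, \\
1 \le \sum_{i=1}^{n} e^{x_i - m} \le n
  &\;\Longrightarrow\; m \le \operatorname{logsumexp}(x) \le m + \log n .
\end{aligned}
```

The upper bound is attained when all entries are equal, as in x = [100, 100, 100], where the result is 100 + log(3).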

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Naive (works here, overflows for large x)
naive = jnp.log(jnp.sum(jnp.exp(x)))   # ≈ 3.4076

# Stable (always works)
stable = jax.nn.logsumexp(x)            # ≈ 3.4076

For x = [100, 100, 100], naive gives inf (in JAX's default float32); jax.nn.logsumexp gives 101.099 (= 100 + log(3)).
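The same trick written out by hand, as a sketch rather than JAX's actual implementation. logsumexp_shifted is a hypothetical name, and the guard for an all -inf input is an addition beyond the formula above: without it, max(x) = -inf and x - max(x) becomes -inf - (-inf) = nan.

```python
import jax
import jax.numpy as jnp


def logsumexp_shifted(x):
    # Hypothetical hand-written max-shift logsumexp for a 1-D array.
    m = jnp.max(x)
    # Added guard (not part of the formula above): if m is not finite, shift by 0
    # instead, so an all -inf input gives log(0) = -inf rather than nan.
    m = jnp.where(jnp.isfinite(m), m, 0.0)
    # For finite m, x - m <= 0: every exp term is in [0, 1] and the max term is exactly 1.
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))


for x in (jnp.array([1.0, 2.0, 3.0]),
          jnp.array([100.0, 100.0, 100.0]),
          jnp.array([-1000.0, -1000.0])):
    # Hand-written result next to JAX's built-in, for comparison.
    print(logsumexp_shifted(x), jax.nn.logsumexp(x))
```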

Common pitfalls

  • Writing log(sum(exp(x))) directly: it works for small values but silently returns inf once any entry is large (e.g., logits in the hundreds), and -inf when every entry is very negative.
  • Computing jnp.max(x) but forgetting to add it back: when you implement the trick by hand, it’s easy to drop the final + max(x), which leaves the result off by exactly max(x).
  • Confusing logsumexp with log_softmax: they’re related but different. logsumexp(x) is a scalar, while log_softmax(x) has the same shape as x, and log_softmax(x) = x - logsumexp(x).
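A quick check of the last two pitfalls, using jax.nn.logsumexp and jax.nn.log_softmax as references; the variable names are illustrative only:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Pitfall: shifting by the max but forgetting to add it back.
m = jnp.max(x)
dropped_max = jnp.log(jnp.sum(jnp.exp(x - m)))  # equals logsumexp(x) - max(x)
correct = m + jnp.log(jnp.sum(jnp.exp(x - m)))  # equals logsumexp(x)

# logsumexp reduces to a scalar; log_softmax returns one value per entry of x.
lse = jax.nn.logsumexp(x)
log_probs = jax.nn.log_softmax(x)
print(lse.shape, log_probs.shape)  # () (3,)
```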

Problem

Implement stable_logsumexp(x), which computes log(sum(exp(x))) in a numerically stable way.

  • x: 1-D jax array.
  • Returns: scalar.
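One way to test an attempt against the edge cases above is a small harness like the one below. check_logsumexp is a hypothetical helper (not part of JAX and not this problem's grader); it treats jax.nn.logsumexp as the reference, and the tolerance is an arbitrary choice:

```python
import jax
import jax.numpy as jnp


def check_logsumexp(fn, rtol=1e-5):
    """Hypothetical helper: compare a candidate 1-D logsumexp `fn` against
    jax.nn.logsumexp on a normal input plus the overflow and underflow cases."""
    cases = [
        jnp.array([1.0, 2.0, 3.0]),
        jnp.array([100.0, 100.0, 100.0]),  # naive version overflows to inf (float32)
        jnp.array([-1000.0, -1000.0]),     # naive version underflows to -inf
    ]
    for x in cases:
        got = jnp.asarray(fn(x))
        expected = jax.nn.logsumexp(x)
        if got.ndim != 0:                # the spec asks for a scalar
            return False
        if not bool(jnp.isfinite(got)):  # all three expected values are finite
            return False
        if not bool(jnp.allclose(got, expected, rtol=rtol)):
            return False
    return True


def naive_logsumexp(x):
    # Direct formula, included only to show that the harness rejects it.
    return jnp.log(jnp.sum(jnp.exp(x)))


print(check_logsumexp(jax.nn.logsumexp))  # True
print(check_logsumexp(naive_logsumexp))   # False
```

The naive formula passes on [1, 2, 3] and only fails on the extreme inputs, which is why spot-checking with small values alone can hide the bug.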

Hints

jax logsumexp stable
