medium primitives

Stable log_softmax

Why this matters

Cross-entropy loss, the backbone of classifier training, requires log(softmax(logits)). There are two ways to compute this, and one is strictly better:

Approach   Code                         Problem
Two-step   jnp.log(jax.nn.softmax(x))   Slower and less precise; returns -inf when a probability underflows to 0
Fused      jax.nn.log_softmax(x)        Faster, more stable, preferred

The key identity is:

log_softmax(x)_i = x_i - logsumexp(x)

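This is numerically stable because logsumexp(x) uses the max-subtraction trick internally: it computes max(x) + log(sum(exp(x - max(x)))), so every exponent is ≤ 0 and cannot overflow. We also never take the log of a softmax probability, which may already have underflowed to 0.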
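To make the difference concrete, here is a minimal sketch comparing the two approaches on logits with a very wide spread. The helper name manual_log_softmax is hypothetical, written only to illustrate the identity; it is not a copy of JAX's internal implementation.

```python
import jax
import jax.numpy as jnp

def manual_log_softmax(x):
    # Hypothetical helper: x_i - logsumexp(x), with logsumexp
    # computed via max-subtraction so no exponent exceeds 0.
    m = jnp.max(x)
    shifted = x - m
    return shifted - jnp.log(jnp.sum(jnp.exp(shifted)))

# A wide spread: the softmax probability of the first element
# underflows to exactly 0 in floating point.
x = jnp.array([0.0, 1000.0])

naive = jnp.log(jax.nn.softmax(x))   # log(0) appears for the first element
fused = jax.nn.log_softmax(x)        # stays finite
manual = manual_log_softmax(x)       # should agree with the fused result

print("naive :", naive)
print("fused :", fused)
print("manual:", manual)
```

The two-step version loses the first log-probability entirely, while the fused and manual versions recover it from the shifted logits.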

Worked mini-example

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
lse = jax.nn.logsumexp(x)    # ≈ 3.4076

log_probs = jax.nn.log_softmax(x)
# = x - lse = [1 - 3.4076, 2 - 3.4076, 3 - 3.4076]
# ≈ [-2.4076, -1.4076, -0.4076]
# sum(exp(log_probs)) ≈ 1.0 (up to floating-point rounding)

For uniform inputs like [0, 0, 0, 0], log_softmax returns log(1/4) = -log(4) ≈ -1.3863 for every element.
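The uniform case does not depend on the shared value: n equal logits always give -log(n), because log_softmax is invariant to adding a constant to every logit. A quick sketch to check this (the constants below are arbitrary illustrative choices):

```python
import jax
import jax.numpy as jnp

n = 4
expected = -jnp.log(float(n))   # -log(4)

results = []
for c in (0.0, 5.0, -100.0):
    x = jnp.full((n,), c)               # n identical logits, all equal to c
    out = jax.nn.log_softmax(x)
    results.append(out)
    print(c, out)

# Every entry of every result should match -log(n).
all_match = all(bool(jnp.allclose(r, expected, atol=1e-5)) for r in results)
print("all match -log(n):", all_match)
```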

Common pitfalls

  • Chaining log(softmax(x)): works for moderate logits but is slower and can return -inf when a probability underflows; use log_softmax directly in production code.
  • Forgetting that the output is log-probabilities: don't exponentiate unless you specifically want the softmax probabilities.
  • Confusing log_softmax with logsumexp: for a 1-D input, logsumexp returns a scalar; log_softmax returns a vector the same shape as the input.
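The shape difference in the last pitfall is easiest to see on a batch. The sketch below uses a small 2-D array of logits chosen purely for illustration; the problem itself only involves 1-D inputs.

```python
import jax
import jax.numpy as jnp

# Illustrative batch: 2 examples, 3 classes each.
logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])

lse = jax.nn.logsumexp(logits, axis=-1)           # reduces the class axis
log_probs = jax.nn.log_softmax(logits, axis=-1)   # keeps the input shape

print(lse.shape)        # one value per row
print(log_probs.shape)  # same shape as logits

# The identity still holds per row once the reduced axis is kept for broadcasting.
recon = logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True)

# Exponentiate only when probabilities are actually needed.
probs = jnp.exp(log_probs)
row_sums = jnp.sum(probs, axis=-1)
```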

Problem

Implement stable_log_softmax(x) that computes log-probabilities stably.

  • x: 1-D jax array (logits).
  • Returns: 1-D array of the same shape containing log-probabilities (all ≤ 0; their exponentials sum to 1, so logsumexp of the output is log(1) = 0).
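One way to sanity-check a candidate implementation against these properties is sketched below. The name check_log_probs and its tolerance are hypothetical, and the finiteness check is an added assumption beyond the stated requirements: without it, an output containing -inf would still satisfy "all ≤ 0" and "exponentials sum to 1". The built-in jax.nn.log_softmax stands in for a submission here.

```python
import jax
import jax.numpy as jnp

def check_log_probs(fn, x, atol=1e-5):
    # Hypothetical checker: verifies shape, sign, normalization,
    # and (as an extra assumption) that every output is finite.
    out = fn(x)
    same_shape = out.shape == x.shape
    non_positive = bool(jnp.all(out <= 0.0))
    normalized = bool(jnp.isclose(jnp.sum(jnp.exp(out)), 1.0, atol=atol))
    finite = bool(jnp.all(jnp.isfinite(out)))
    return same_shape and non_positive and normalized and finite

x_small = jnp.array([1.0, 2.0, 3.0])
x_wide = jnp.array([0.0, 1000.0])

def two_step(v):
    # The chained form from the pitfalls list, used as a negative example.
    return jnp.log(jax.nn.softmax(v))

ok_small = check_log_probs(jax.nn.log_softmax, x_small)
ok_wide = check_log_probs(jax.nn.log_softmax, x_wide)
two_step_wide = check_log_probs(two_step, x_wide)

print(ok_small, ok_wide, two_step_wide)
```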

Hints

jax softmax stable
