
NNX Label Smoothing

Why this matters

Standard cross-entropy with one-hot targets pushes the model to be 100% confident on the correct class: the loss only reaches 0 as the correct logit grows infinitely larger than the others. But 100% confidence is rarely justified on real data. Chasing it encourages overfitting, produces miscalibrated (overconfident) probabilities, and amplifies noise from mislabeled examples.

Label smoothing softens the target: for K = 4 classes with the correct class first, instead of [1, 0, 0, 0] you use [1 - α + α/K, α/K, α/K, α/K] (sometimes written as [1 - α, 0, 0, 0] + α/K). The model learns to be confident, but not infinitely so.
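As a concrete check, here is a minimal sketch assuming α = 0.1, K = 4 and a single example whose correct class is index 0. It builds the smoothed target with jax.nn.one_hot and confirms the row is still a probability distribution:

```python
import jax
import jax.numpy as jnp

alpha, num_classes = 0.1, 4
labels = jnp.array([0])                          # one example, correct class 0

one_hot = jax.nn.one_hot(labels, num_classes)    # [[1., 0., 0., 0.]]
smoothed = (1 - alpha) * one_hot + alpha / num_classes

print(smoothed)               # approximately [[0.925, 0.025, 0.025, 0.025]]
print(smoothed.sum(axis=-1))  # each row still sums to (approximately) 1
```

The correct class keeps 1 - α + α/K = 0.925 of the mass and each of the other three classes gets α/K = 0.025.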

Used in Inception-v3 (where it was introduced, α = 0.1), the original Transformer (α = 0.1), most modern image classifiers, and many LLM fine-tuning recipes. It is one of the cheapest regularizers: a scale and an add on the target vector, with no extra parameters or forward passes.

This problem is intentionally framework-free: a label-smoothed cross-entropy is just JAX. nnx adds nothing here, and that’s the teaching point: not everything in the training stack needs to be framework-aware.

The math

For logits z of shape (N, K) and integer labels y of shape (N,):

target_n = (1 - α) * one_hot(y_n, K) + α / K
log_p_n  = log_softmax(z_n)
loss_n   = -sum_k target_n[k] * log_p_n[k]
loss     = mean(loss_n)

Equivalently:

loss_n = (1 - α) * (-log_p_n[y_n]) + α * (-mean_k log_p_n[k])
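The second form follows by substituting the smoothed target into the sum and splitting it, where \mathbb{1}[k = y_n] is the indicator that k is the true label (the same quantity as one_hot(y_n, K)[k]):

```latex
\begin{aligned}
\text{loss}_n
  &= -\sum_{k=1}^{K}\Big[(1-\alpha)\,\mathbb{1}[k = y_n] + \tfrac{\alpha}{K}\Big]\log p_n[k] \\
  &= (1-\alpha)\big(-\log p_n[y_n]\big) + \alpha\Big(-\tfrac{1}{K}\sum_{k=1}^{K}\log p_n[k]\Big)
\end{aligned}
```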

The two forms are mathematically identical; the first is what most code looks like.
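A quick numerical sanity check of the equivalence, as a sketch on a small random batch (the sizes N = 3, K = 5 and the labels are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp

# Small random batch: N = 3 examples, K = 5 classes.
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (3, 5))
labels = jnp.array([1, 4, 0])
alpha, K = 0.1, 5

log_p = jax.nn.log_softmax(logits, axis=-1)

# Form 1: explicit smoothed targets, summed over classes.
target = (1 - alpha) * jax.nn.one_hot(labels, K) + alpha / K
form1 = -jnp.sum(target * log_p, axis=-1)

# Form 2: weighted mix of standard CE and the uniform (mean over classes) term.
nll = -jnp.take_along_axis(log_p, labels[:, None], axis=-1)[:, 0]
uniform = -jnp.mean(log_p, axis=-1)
form2 = (1 - alpha) * nll + alpha * uniform

print(jnp.allclose(form1, form2, atol=1e-6))  # True
```

Form 2 is handy when you already have a standard CE value and want to reason about the smoothing term separately; form 1 maps directly onto the implementation.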

Why log_softmax instead of log(softmax(...))?

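log_softmax is numerically stable: it computes the equivalent of z - max(z) - log(sum(exp(z - max(z)))), so large logits never overflow and the log is never applied to a probability that has already underflowed to 0. The naive composition exponentiates first and takes the log second; when one logit dominates, the smaller probabilities round to 0 and log(0) = -inf. Always use jax.nn.log_softmax, never jnp.log(jax.nn.softmax(...)).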
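A minimal demonstration, assuming default float32 precision: with logits [0, -200], exp(-200) (about 1e-87) is far below the smallest float32 value, so the naive path loses the second probability entirely while log_softmax keeps it:

```python
import jax
import jax.numpy as jnp

# One dominant logit: exp(-200) underflows to 0 in float32.
logits = jnp.array([0.0, -200.0])

naive = jnp.log(jax.nn.softmax(logits))   # second entry is log(0) = -inf
stable = jax.nn.log_softmax(logits)       # second entry stays finite, about -200

print(naive)
print(stable)
```

Multiplying the naive -inf by any positive smoothed target sends the loss to inf, which is exactly the failure the stable version avoids.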

Edge cases

  • α = 0: standard cross-entropy. Targets are exactly one-hot.
  • α = 1.0: the target is uniform, 1/K for every class. The loss reduces to -mean_k log_p_n[k] regardless of the label, so the labels are ignored entirely and the optimum is a uniform prediction. Not useful, but well-defined.
  • The harness passes labels_int as a float tensor; cast with labels_int.astype(jnp.int32) before jax.nn.one_hot.
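Both edge cases and the float-label cast can be checked with a short sketch. The helper name smoothed_ce is hypothetical (not part of the problem's API); it returns per-example losses and casts labels the way the harness requires:

```python
import jax
import jax.numpy as jnp

def smoothed_ce(logits, labels, alpha):
    """Hypothetical helper: per-example label-smoothed cross-entropy, shape (N,)."""
    labels = labels.astype(jnp.int32)             # harness passes float labels
    K = logits.shape[-1]
    target = (1 - alpha) * jax.nn.one_hot(labels, K) + alpha / K
    return -jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1)

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0.0, 2.0])                    # float class IDs, as from the harness
log_p = jax.nn.log_softmax(logits, axis=-1)

# alpha = 0: plain cross-entropy, -log p[y].
plain_ce = -log_p[jnp.arange(2), labels.astype(jnp.int32)]
print(jnp.allclose(smoothed_ce(logits, labels, 0.0), plain_ce))   # True

# alpha = 1: the label is ignored; the loss is -mean_k log p[k].
other_labels = jnp.array([1.0, 0.0])
print(jnp.allclose(smoothed_ce(logits, labels, 1.0),
                   smoothed_ce(logits, other_labels, 1.0)))        # True
```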

Common pitfalls

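  • Smoothing the loss instead of the targets. People sometimes add a -α * mean log p term to standard CE. That keeps the CE term at full weight instead of (1 - α): the result equals (1 + α) times a label-smoothed loss with α / (1 + α), so both the effective smoothing strength and the gradient scale differ from what you intended. Smooth the TARGETS.
  • Taking the log of softmax. Computing p = softmax(z) and then summing target * log(p) can hit log(0) = -inf when a probability underflows. With α > 0 the loss becomes inf; with α = 0 an underflowed wrong-class probability gives 0 * -inf = NaN. Use log_softmax and skip the explicit log.
  • Sum vs mean. The per-example loss sums over classes; the batch loss averages over examples. Averaging over classes instead divides the loss (and its gradient) by K; summing over the batch instead multiplies it by N.
  • One-hot dtype. jax.nn.one_hot(int_labels, K) returns float32 by default (controlled by its dtype argument), which is fine here. If the logits are bfloat16 or float16, the float32 targets promote the product to float32; pass dtype=logits.dtype if you want to stay in lower precision.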
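The first pitfall can be verified numerically. In this sketch (the helper names smoothed_loss and loss_with_penalty are hypothetical), adding α times the uniform term to full-weight CE does not match smoothing with α; it equals (1 + α) times a correctly smoothed loss at α / (1 + α):

```python
import jax
import jax.numpy as jnp

def smoothed_loss(logits, labels, alpha):
    """Correct version: smooth the targets, then take the mean cross-entropy."""
    K = logits.shape[-1]
    target = (1 - alpha) * jax.nn.one_hot(labels, K) + alpha / K
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return jnp.mean(-jnp.sum(target * log_p, axis=-1))

def loss_with_penalty(logits, labels, alpha):
    """Pitfall: full-weight CE plus alpha * (-mean_k log p[k])."""
    log_p = jax.nn.log_softmax(logits, axis=-1)
    ce = -jnp.take_along_axis(log_p, labels[:, None], axis=-1)[:, 0]
    return jnp.mean(ce + alpha * (-jnp.mean(log_p, axis=-1)))

logits = jax.random.normal(jax.random.PRNGKey(1), (4, 6))
labels = jnp.array([0, 3, 5, 2])
alpha = 0.1

wrong = loss_with_penalty(logits, labels, alpha)
right = smoothed_loss(logits, labels, alpha)
rescaled = (1 + alpha) * smoothed_loss(logits, labels, alpha / (1 + alpha))

print(jnp.allclose(wrong, rescaled))  # True
print(jnp.allclose(wrong, right))     # False: they differ by alpha * mean CE
```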

Problem

Implement label_smoothing_loss(logits, labels_int, num_classes, alpha):

  1. Cast labels_int to int32, num_classes to int, alpha to float.
  2. one_hot = jax.nn.one_hot(labels, num_classes).
  3. smoothed = (1 - α) * one_hot + α / num_classes.
  4. log_probs = jax.nn.log_softmax(logits, axis=-1).
  5. per_example = -jnp.sum(smoothed * log_probs, axis=-1).
  6. Return jnp.array([float(jnp.mean(per_example))]) as 1-D (1,).

Pure JAX; no nnx model needed.

Inputs:

  • logits: 2-D (N, K).
  • labels_int: 1-D (N,), integer class IDs (passed as float, cast to int32).
  • num_classes: float (cast to int).
  • alpha: float in [0, 1]. 0 = standard CE.

Output: 1-D array of shape (1,) containing the mean label-smoothed cross-entropy.
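One possible sketch of the six steps above, assuming the harness calls the function with plain JAX or NumPy arrays and Python floats as described in Inputs. It is an illustration, not the reference solution:

```python
import jax
import jax.numpy as jnp

def label_smoothing_loss(logits, labels_int, num_classes, alpha):
    # Step 1: normalize input types (labels and num_classes arrive as floats).
    labels = jnp.asarray(labels_int).astype(jnp.int32)
    num_classes = int(num_classes)
    alpha = float(alpha)

    # Steps 2-3: smoothed target distribution, shape (N, K).
    one_hot = jax.nn.one_hot(labels, num_classes)
    smoothed = (1.0 - alpha) * one_hot + alpha / num_classes

    # Step 4: numerically stable log-probabilities.
    log_probs = jax.nn.log_softmax(logits, axis=-1)

    # Step 5: per-example cross-entropy, summed over classes, shape (N,).
    per_example = -jnp.sum(smoothed * log_probs, axis=-1)

    # Step 6: batch mean, returned as a 1-D array of shape (1,).
    return jnp.array([float(jnp.mean(per_example))])

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0.0, 2.0])
out = label_smoothing_loss(logits, labels, 3.0, 0.1)
print(out.shape)  # (1,)
```

The float(...) in step 6 pulls the value to the host, which matches the required output but will fail inside jax.jit; in a jitted training step, jnp.mean(per_example).reshape(1) keeps the computation traceable.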

Hints

flax nnx loss label-smoothing cross-entropy
