NNX Label Smoothing
Why this matters
Standard cross-entropy with one-hot targets pushes the model toward 100% confidence on the correct class. That level of confidence is rarely warranted on real data: it encourages overfitting, produces miscalibrated probabilities, and amplifies noise from mislabeled examples.
Label smoothing softens the target: instead of [1, 0, 0, 0]
you use [1 - α + α/K, α/K, α/K, α/K] for K classes (sometimes
written as [1 - α, 0, 0, 0] + α/K). The model learns to be
confident, but not infinitely so.
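As a small illustration of the target construction described above, the sketch below builds the smoothed target for one example. The values α = 0.1 and K = 4 are assumptions chosen for illustration only.

```python
import jax
import jax.numpy as jnp

alpha = 0.1            # smoothing strength (illustrative value)
num_classes = 4        # K (illustrative value)
label = jnp.array(0)   # index of the correct class

one_hot = jax.nn.one_hot(label, num_classes)              # [1, 0, 0, 0]
smoothed = (1.0 - alpha) * one_hot + alpha / num_classes  # soften the target

print(smoothed)        # correct class gets 1 - α + α/K, the others get α/K
print(smoothed.sum())  # the smoothed target still sums to 1
```

The smoothed vector remains a valid probability distribution, which is why it can be plugged straight into the usual cross-entropy sum.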
Label smoothing was introduced with Inception-v3 (α = 0.1) and is also used in the Transformer (α = 0.1), many modern image classifiers, and many LLM fine-tuning recipes. It is one of the cheapest regularizers: roughly one extra arithmetic op on the targets.
This problem is intentionally framework-free: a label-smoothed cross-entropy is just JAX. nnx adds nothing here, and that’s the teaching point: not everything in the training stack needs to be framework-aware.
The math
For logits z of shape (N, K) and integer labels y of shape
(N,):
target_n = (1 - α) * one_hot(y_n, K) + α / K
log_p_n = log_softmax(z_n)
loss_n = -sum_k target_n[k] * log_p_n[k]
loss = mean(loss_n)
Equivalently:
loss_n = (1 - α) * (-log_p_n[y_n]) + α * (-mean_k log_p_n[k])
The two forms are mathematically identical; the first is what most code looks like.
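One way to convince yourself that the two forms agree is to compute both on random logits and compare. The function names loss_target_form and loss_split_form are made up for this sketch, and the shapes and α value are arbitrary.

```python
import jax
import jax.numpy as jnp

def loss_target_form(logits, labels, alpha):
    # First form: build smoothed targets, then take the full cross-entropy sum.
    k = logits.shape[-1]
    target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return jnp.mean(-jnp.sum(target * log_p, axis=-1))

def loss_split_form(logits, labels, alpha):
    # Second form: weighted mix of the usual NLL and a uniform-target term.
    log_p = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_p, labels[:, None], axis=-1)[:, 0]
    uniform = -jnp.mean(log_p, axis=-1)
    return jnp.mean((1.0 - alpha) * nll + alpha * uniform)

logits = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
labels = jnp.array([0, 2, 1, 1, 0])

a = loss_target_form(logits, labels, 0.1)
b = loss_split_form(logits, labels, 0.1)
print(jnp.allclose(a, b, atol=1e-5))  # the two forms match up to float error
```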
Why log_softmax instead of log(softmax(...))?
log_softmax is numerically stable: under the hood it subtracts
the max logit before exponentiating, avoiding overflow. The naive
composition can produce log(0) = -inf when one logit dominates.
Always use jax.nn.log_softmax, never jnp.log(jax.nn.softmax(...)).
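The failure mode can be reproduced in a few lines. This sketch assumes JAX's default float32 precision and uses a deliberately extreme pair of logits; in this setting the small probability underflows to zero before the log is taken.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([0.0, 200.0])  # one logit dominates (float32 by default)

# Naive composition: the small probability underflows to 0, so log gives -inf.
naive = jnp.log(jax.nn.softmax(logits))

# Stable version: works in log space and never forms the tiny probability.
stable = jax.nn.log_softmax(logits)

print(naive)
print(stable)
```

Multiplying a -inf entry by a smoothed target of α/K (which is nonzero whenever α > 0) would make the whole loss infinite, so the stable form matters more with label smoothing than with plain one-hot targets.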
Edge cases
- α = 0: standard cross-entropy. Targets are exactly one-hot.
- α = 1.0: target is uniform 1/K. Loss is just -mean_k log_p_n[k] regardless of the label, which amounts to matching a uniform prior. Not useful, but well-defined.
- The harness passes labels_int as a float tensor; cast with labels_int.astype(jnp.int32) before jax.nn.one_hot.
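The edge cases above can be checked numerically. The helper smoothed_ce below is a hypothetical name for this sketch, and the logits and float-typed labels are made-up inputs that mimic how the harness is described as passing labels.

```python
import jax
import jax.numpy as jnp

def smoothed_ce(logits, labels_int, alpha):
    labels = labels_int.astype(jnp.int32)  # labels arrive as floats; cast first
    k = logits.shape[-1]
    target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return jnp.mean(-jnp.sum(target * log_p, axis=-1))

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
labels_a = jnp.array([0.0, 2.0])  # float-typed class IDs
labels_b = jnp.array([1.0, 1.0])  # a different labeling of the same batch

# α = 0 should reproduce plain cross-entropy.
log_p = jax.nn.log_softmax(logits, axis=-1)
plain_ce = -jnp.mean(log_p[jnp.arange(2), labels_a.astype(jnp.int32)])
at_zero = smoothed_ce(logits, labels_a, 0.0)

# α = 1 should ignore the labels entirely.
at_one_a = smoothed_ce(logits, labels_a, 1.0)
at_one_b = smoothed_ce(logits, labels_b, 1.0)

print(jnp.allclose(at_zero, plain_ce))
print(jnp.allclose(at_one_a, at_one_b))
```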
Common pitfalls
- Smoothing the loss instead of the targets. People sometimes add a -α * mean log p term to standard CE, which gives a different (worse) gradient because the true-label term keeps weight 1 instead of (1 - α). Smooth the TARGETS.
- Calling softmax and taking the log yourself. Summing target * log(p) can then produce log(0) = -inf. Use log_softmax and skip the explicit log.
- Sum vs mean. Per-example loss sums over classes; the batch loss takes the mean over examples. Mixing these changes your effective gradient scale by a factor of K or N.
- One-hot dtype mismatch. jax.nn.one_hot(int_labels, K) returns a floating-point array (float32 by default). That is fine for our use, but note that the result is no longer an integer array.
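To make the first pitfall concrete, the sketch below compares the additive variant (standard CE plus an α-weighted uniform term) with proper target smoothing. The variable names and inputs are illustrative assumptions; the point is only that the two losses differ by exactly α times the plain cross-entropy.

```python
import jax
import jax.numpy as jnp

alpha = 0.1
logits = jnp.array([[1.5, -0.5, 0.2], [0.0, 2.0, -1.0]])
labels = jnp.array([0, 1])
k = logits.shape[-1]

log_p = jax.nn.log_softmax(logits, axis=-1)
ce = -jnp.mean(log_p[jnp.arange(2), labels])   # plain cross-entropy
uniform_term = -jnp.mean(log_p)                # mean over examples and classes

# Pitfall: add the uniform term on top of an unscaled CE.
additive = ce + alpha * uniform_term

# Intended: smooth the targets, which also scales the CE term by (1 - α).
target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
smoothed = jnp.mean(-jnp.sum(target * log_p, axis=-1))

print(additive - smoothed)  # equals alpha * ce, so the losses are not the same
```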
Problem
Implement label_smoothing_loss(logits, labels_int, num_classes, alpha):
- Cast labels_int to int32, num_classes to int, alpha to float.
- one_hot = jax.nn.one_hot(labels, num_classes).
- smoothed = (1 - α) * one_hot + α / num_classes.
- log_probs = jax.nn.log_softmax(logits, axis=-1).
- per_example = -jnp.sum(smoothed * log_probs, axis=-1).
- Return jnp.array([float(jnp.mean(per_example))]) as 1-D (1,).
Pure JAX; no nnx model needed.
Inputs:
- logits: 2-D (N, K).
- labels_int: 1-D (N,), integer class IDs (passed as float, cast to int32).
- num_classes: float (cast to int).
- alpha: float in [0, 1]. 0 = standard CE.
Output: 1-D (1,) — [mean_label_smoothed_cross_entropy].
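Following the listed steps literally, one possible sketch looks like the code below. It is not the site's reference solution, and the test inputs are made up. Note that the float(...) call in the last step forces a concrete value, so this version is not meant to be wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp

def label_smoothing_loss(logits, labels_int, num_classes, alpha):
    # Step 1: normalize input types (labels and num_classes arrive as floats).
    labels = jnp.asarray(labels_int).astype(jnp.int32)
    num_classes = int(num_classes)
    alpha = float(alpha)
    # Steps 2-3: one-hot targets, then smooth them.
    one_hot = jax.nn.one_hot(labels, num_classes)
    smoothed = (1.0 - alpha) * one_hot + alpha / num_classes
    # Step 4: numerically stable log-probabilities.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Step 5: sum over classes for each example.
    per_example = -jnp.sum(smoothed * log_probs, axis=-1)
    # Step 6: mean over the batch, returned as a 1-D array of shape (1,).
    return jnp.array([float(jnp.mean(per_example))])

# With all-zero logits every class has probability 1/K, so the loss is log(K)
# for any alpha and any labels.
logits = jnp.zeros((4, 3))
labels = jnp.array([0.0, 1.0, 2.0, 1.0])
out = label_smoothing_loss(logits, labels, 3.0, 0.1)
print(out.shape)
```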