Label-Smoothed Cross-Entropy
Why this matters
Standard cross-entropy with one-hot targets pushes the network’s
softmax output toward 1 on the true class and 0 on every other class.
To reach exactly 1.0, the true-class logit would have to exceed every
other logit by an unbounded margin (softmax only sees logit
differences), so the logits keep growing. The model becomes extremely
confident, and extremely brittle.
Label smoothing mixes a tiny amount of uniform mass into the
target. Instead of [0, 0, 1, 0] you train against
[α/K, α/K, 1 - α + α/K, α/K]. The network can no longer drive any
output to exactly 1 or 0; gradients stay informative; calibration
improves.
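One way to see why no output saturates: the gradient of cross-entropy with respect to the logits is softmax(z) - target. With a smoothed target, that gradient reaches zero at finite logits (z = log(target), up to a constant shift), so the true-class logit only needs to lead the others by log((1 - α + α/K) / (α/K)). The sketch below checks this numerically for K = 4, α = 0.1; the helper name smoothed_ce is illustrative, not from the original.

```python
import jax
import jax.numpy as jnp

def smoothed_ce(logits, target):
    """Cross-entropy of one example's logits (K,) against a soft target (K,)."""
    return -jnp.sum(target * jax.nn.log_softmax(logits))

K, alpha, true_class = 4, 0.1, 2
# Smoothed target: alpha/K on every class, plus (1 - alpha) on the true class.
target = jnp.full((K,), alpha / K).at[true_class].add(1.0 - alpha)

# Logits whose softmax equals the target (adding any constant also works).
best_logits = jnp.log(target)

# d(loss)/d(logits) = softmax(logits) - target, which is ~0 at these logits.
grad = jax.grad(smoothed_ce)(best_logits, target)

# The logit gap needed is finite: log(0.925 / 0.025) = log(37), about 3.61.
gap = best_logits[true_class] - best_logits[0]
```

With one-hot targets, by contrast, softmax(z) - target stays nonzero for every finite z, so gradient descent keeps widening the gap.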
Used in: ResNet/BiT image classification, the original Transformer
paper (α = 0.1), Inception v3, and many modern speech models.
The formula
Given logits Z ∈ R^{N×K} and integer labels y ∈ {0, ..., K-1}^N:
target_n = (1 - α) * one_hot(y_n, K) + α / K # shape (K,)
loss_n = -sum_k target_{n,k} * log_softmax(Z_n)_k # scalar
loss = mean_n loss_n
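The three lines of the formula translate almost one-to-one into JAX. A minimal sketch, assuming integer labels and a Python float alpha; the function name label_smoothed_ce is illustrative:

```python
import jax
import jax.numpy as jnp

def label_smoothed_ce(logits, labels, alpha):
    """Mean label-smoothed cross-entropy.

    logits: (N, K) float array; labels: (N,) integer array; alpha: float in [0, 1].
    """
    num_classes = logits.shape[-1]
    # target_n = (1 - alpha) * one_hot(y_n, K) + alpha / K, shape (N, K)
    target = (1.0 - alpha) * jax.nn.one_hot(labels, num_classes) + alpha / num_classes
    # loss_n = -sum_k target_{n,k} * log_softmax(Z_n)_k, shape (N,)
    per_example = -jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1)
    # loss = mean_n loss_n, a scalar
    return jnp.mean(per_example)

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])
loss = label_smoothed_ce(logits, labels, alpha=0.1)
```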
Note: this is exactly cross-entropy with a smoothed target — the
same code path as standard CE, just a different target. When
α = 0 you recover standard cross-entropy.
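Expanding the smoothed target gives an equivalent form of the same loss: loss_n = (1 - α) * CE(Z_n, y_n) + α * mean_k(-log_softmax(Z_n)_k), i.e. ordinary cross-entropy blended with a term that penalizes low log-probability on every class. The sketch below checks that identity numerically and confirms that α = 0 reduces to plain cross-entropy; both helper names are illustrative.

```python
import jax
import jax.numpy as jnp

def smoothed_ce_per_example(logits, labels, alpha):
    """Per-example loss computed from the smoothed target directly, shape (N,)."""
    k = logits.shape[-1]
    target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
    return -jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1)

def decomposed_per_example(logits, labels, alpha):
    """Same loss as (1 - alpha) * standard CE + alpha * mean_k(-log p_k)."""
    log_p = jax.nn.log_softmax(logits, axis=-1)
    # Standard CE picks out -log p at the true class for each example.
    standard_ce = -jnp.take_along_axis(log_p, labels[:, None], axis=-1)[:, 0]
    # Uniform term averages -log p over all K classes.
    uniform_term = -jnp.mean(log_p, axis=-1)
    return (1.0 - alpha) * standard_ce + alpha * uniform_term

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (5, 4))
labels = jnp.array([0, 3, 1, 2, 2])

direct = smoothed_ce_per_example(logits, labels, 0.1)
split = decomposed_per_example(logits, labels, 0.1)
```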
Why subtract α and add α/K to the true class?
Sum check: the smoothed target must sum to 1 (still a probability):
(1 - α) * 1 + α/K * K = (1 - α) + α = 1 ✓
Off-class slots get α / K. The true-class slot gets
(1 - α) + α/K = 1 - α + α/K. As α → 0 you get 1; as α → 1 you
get 1/K (uniform — useless training signal).
Worked example (K=4, α=0.1, true class = 2)
one_hot = [0,     0,     1,     0    ]
target  = [0.025, 0.025, 0.925, 0.025]  # sums to 1, peak at true class
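The worked example can be reproduced directly with jax.nn.one_hot, along with the sum check and the α = 1 limit from the derivation above. A small sketch; the variable names are illustrative:

```python
import jax
import jax.numpy as jnp

K, alpha = 4, 0.1
labels = jnp.array([2])  # a single example whose true class is 2

one_hot = jax.nn.one_hot(labels, K)           # [[0, 0, 1, 0]]
target = (1.0 - alpha) * one_hot + alpha / K  # approximately [[0.025, 0.025, 0.925, 0.025]]
total = target.sum()                          # approximately 1.0 (still a distribution)

# Limiting case alpha = 1: the label is ignored and the target is uniform 1/K.
alpha_max = 1.0
uniform = (1.0 - alpha_max) * one_hot + alpha_max / K
```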
Common pitfalls
- Forgetting .astype(jnp.int32) on the labels before one_hot:
  jax.nn.one_hot expects integer class indices, and the labels here
  arrive as floats.
- Confusing softmax with log_softmax: you want log_softmax because
  the loss is -sum(target * log_p).
- Using softmax_cross_entropy_with_integer_labels: that's the
  non-smoothed variant, and it can't take a soft target. Use the
  softmax_cross_entropy(logits, labels=target) style, or roll it by
  hand as in the formula above.
- Reducing on the wrong axis: the per-example loss sums over classes
  (axis=-1); the batch mean averages over examples (axis=0).
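The dtype and axis pitfalls are easiest to catch with shape checks. A minimal sketch with N = 2 examples and K = 3 classes (chosen to differ so the wrong axis shows up in the shape); all values are illustrative:

```python
import jax
import jax.numpy as jnp

# Labels delivered as floats: cast to exact integer indices before one_hot.
labels_float = jnp.array([2.0, 0.0])
labels = labels_float.astype(jnp.int32)

logits = jnp.array([[0.2, 0.1, 1.5], [2.0, -1.0, 0.3]])  # (N=2, K=3)
alpha, k = 0.1, logits.shape[-1]
target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k

# log_softmax gives log-probabilities directly and is more numerically
# stable than jnp.log(jax.nn.softmax(...)).
log_p = jax.nn.log_softmax(logits, axis=-1)

# Correct: sum over the class axis -> one loss per example, shape (N,).
per_example = -jnp.sum(target * log_p, axis=-1)
# Correct: average over the example axis -> scalar.
loss = jnp.mean(per_example, axis=0)

# Wrong: summing over axis=0 mixes examples and yields shape (K,).
wrong = -jnp.sum(target * log_p, axis=0)
```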
Problem
Compute the label-smoothed cross-entropy:
target = (1 - alpha) * one_hot(labels, K) + alpha / K
loss = mean(-sum(target * log_softmax(logits, axis=-1), axis=-1))
Return the scalar loss wrapped in a 1-D (1,) array.
Inputs:
- logits: 2-D (N, K).
- labels_int: 1-D (N,), integer class labels (delivered as floats; cast inside).
- num_classes: scalar (cast to int), K.
- alpha: scalar, smoothing strength in [0, 1].
Output: 1-D (1,), [mean_loss].
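A sketch of how the stated interface could be met, assuming the inputs arrive as concrete JAX arrays; the function name label_smoothed_ce_problem is hypothetical and the grader's actual harness is not shown in the original:

```python
import jax
import jax.numpy as jnp

def label_smoothed_ce_problem(logits, labels_int, num_classes, alpha):
    """Hypothetical entry point for the problem's input/output contract.

    logits: (N, K); labels_int: (N,) floats holding integer class ids;
    num_classes: scalar K; alpha: scalar smoothing strength.
    Returns a (1,) array [mean_loss].
    """
    k = int(num_classes)                                 # one_hot needs a concrete Python int
    labels = jnp.asarray(labels_int).astype(jnp.int32)   # float labels -> integer indices
    alpha = jnp.asarray(alpha, dtype=logits.dtype)
    target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
    loss = jnp.mean(-jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1))
    return jnp.reshape(loss, (1,))                       # scalar -> (1,)

logits = jnp.array([[1.0, 2.0, 0.5], [0.3, -0.2, 2.2]])
out = label_smoothed_ce_problem(logits, jnp.array([1.0, 2.0]), jnp.array(3.0), jnp.array(0.1))
```

Because int(num_classes) needs a concrete value, this version is not meant to run under jax.jit with num_classes traced; if you jit it, pass num_classes as a plain Python int marked static.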