Label-Smoothed Cross-Entropy

Why this matters

Standard cross-entropy with one-hot targets pushes the network's softmax output toward 1 on the true class and 0 on every other class. Softmax never reaches exactly 1.0 with finite logits. Approaching it requires the gap between the true-class logit and every other logit to grow without bound, so the loss keeps rewarding ever-larger logits. The model becomes extremely confident, and extremely brittle.

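Label smoothing mixes a small amount of uniform probability mass into the target. Instead of [0, 0, 1, 0] you train against [α/K, α/K, 1 - α + α/K, α/K]. The loss is then minimized when the softmax output equals this smoothed target, and finite logits can reach that point, so there is no pressure to push logits to infinity. Gradients stay informative, and calibration often improves.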
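To see why the smoothed target removes the pull toward infinite logits, here is a small sketch with K = 4, α = 0.1 and true class 2 (values chosen for illustration). It sets the logits to log(target), where softmax reproduces the smoothed target exactly, and uses jax.grad to show that the smoothed loss is stationary there while the one-hot loss still pushes the true-class logit up. The helper `ce` is our own name, not a library function.

```python
import jax
import jax.numpy as jnp

K, alpha, true_class = 4, 0.1, 2

hard = jax.nn.one_hot(jnp.array(true_class), K)    # [0., 0., 1., 0.]
smoothed = (1 - alpha) * hard + alpha / K          # [0.025, 0.025, 0.925, 0.025]


def ce(logits, target):
    # Cross-entropy of softmax(logits) against a (possibly soft) target.
    return -jnp.sum(target * jax.nn.log_softmax(logits))


# softmax(log(t)) == t, so z_star is an optimum for the smoothed target.
z_star = jnp.log(smoothed)

# d(ce)/d(logits) = softmax(logits) - target
grad_smoothed = jax.grad(ce)(z_star, smoothed)  # ~0 in every slot: stationary
grad_hard = jax.grad(ce)(z_star, hard)          # true-class slot is negative

# The optimal gap between true-class and off-class logits is finite: log(37) ≈ 3.61.
gap = z_star[true_class] - z_star[0]
```

With the one-hot target the true-class gradient is negative at any finite logits, so gradient descent keeps raising that logit; with the smoothed target it stops at a finite gap.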

Used in: Inception v3 (which introduced it, with α = 0.1), the original Transformer paper (α = 0.1), ResNet/BiT image classification, and many modern speech models.

The formula

Given logits Z ∈ R^{N×K}, integer labels y ∈ {0, ..., K-1}^N, and a smoothing strength α ∈ [0, 1]:

target_n = (1 - α) * one_hot(y_n, K) + α / K              # shape (K,)
loss_n   = -sum_k target_{n,k} * log_softmax(Z_n)_k       # scalar
loss     = mean_n loss_n
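A minimal JAX sketch of the formula above. The function name `label_smoothed_ce` is our own, not a library API, and it reads K from `logits.shape[-1]` instead of taking it as an argument.

```python
import jax
import jax.numpy as jnp


def label_smoothed_ce(logits: jax.Array, labels: jax.Array, alpha: float) -> jax.Array:
    """Mean label-smoothed cross-entropy.

    logits: (N, K) floats. labels: (N,) integer class indices. alpha in [0, 1].
    """
    k = logits.shape[-1]                               # K as a static Python int
    one_hot = jax.nn.one_hot(labels, k)                # (N, K)
    target = (1.0 - alpha) * one_hot + alpha / k       # smoothed target, rows sum to 1
    log_p = jax.nn.log_softmax(logits, axis=-1)        # (N, K) log-probabilities
    per_example = -jnp.sum(target * log_p, axis=-1)    # (N,): sum over classes
    return jnp.mean(per_example)                       # scalar: mean over the batch


# Usage: uniform logits give log K per example, whatever alpha is.
loss = label_smoothed_ce(jnp.zeros((3, 5)), jnp.array([0, 2, 4]), 0.1)
```

Reading K from the shape keeps it a concrete Python int, which `jax.nn.one_hot` needs when the function is wrapped in `jax.jit`.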

Note: this is exactly cross-entropy with a smoothed target — the same code path as standard CE, just a different target. When α = 0 you recover standard cross-entropy.
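Distributing the target inside the sum gives an equivalent form, loss_n = (1 - α) * (-log p_{n,y_n}) + α * mean_k(-log p_{n,k}): a weighted mix of standard CE and CE against the uniform distribution. The sketch below checks that identity numerically on random logits, including the α = 0 case. `hard_ce`, `uniform_ce` and `smoothed_ce` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def hard_ce(logits, labels):
    # Standard CE: -log p of the true class, averaged over the batch.
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.take_along_axis(log_p, labels[:, None], axis=-1))


def uniform_ce(logits):
    # CE against the uniform target 1/K: mean of -log p over classes and batch.
    return -jnp.mean(jax.nn.log_softmax(logits, axis=-1))


def smoothed_ce(logits, labels, alpha):
    k = logits.shape[-1]
    target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
    return -jnp.mean(jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1))


logits = jax.random.normal(jax.random.PRNGKey(0), (8, 5))
labels = jnp.array([0, 1, 2, 3, 4, 0, 1, 2])

alpha = 0.1
lhs = smoothed_ce(logits, labels, alpha)
rhs = (1 - alpha) * hard_ce(logits, labels) + alpha * uniform_ce(logits)
```

The decomposition also makes the regularizer explicit: the α term penalizes any class whose log-probability drifts far below the others.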

Why subtract α and add α/K to the true class?

Sum check: the smoothed target must sum to 1 (still a probability):

(1 - α) * 1 + α/K * K = (1 - α) + α = 1   ✓

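Off-class slots get α/K. The true-class slot gets (1 - α) + α/K = 1 - α + α/K. As α → 0 the true-class slot goes to 1 (plain one-hot). At α = 1 every slot, the true class included, equals 1/K: the target is uniform and carries no information about the label.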
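The same sum check and the two limits, written as a quick JAX sketch. `smooth` is a hypothetical helper; Optax ships a comparable utility, `optax.smooth_labels`, which this sketch does not use.

```python
import jax
import jax.numpy as jnp


def smooth(one_hot_target, alpha):
    # Hypothetical helper: (1 - alpha) * one_hot + alpha / K.
    k = one_hot_target.shape[-1]
    return (1 - alpha) * one_hot_target + alpha / k


K, true_class = 4, 2
one_hot = jax.nn.one_hot(jnp.array(true_class), K)

for alpha in (0.0, 0.1, 0.5, 1.0):
    t = smooth(one_hot, alpha)
    assert jnp.isclose(t.sum(), 1.0)                          # still a distribution
    assert jnp.isclose(t[true_class], 1 - alpha + alpha / K)  # true-class slot
    assert jnp.isclose(t[0], alpha / K)                       # an off-class slot
```

At α = 0 `smooth` returns the one-hot vector unchanged; at α = 1 it returns [0.25, 0.25, 0.25, 0.25].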

Worked example (K=4, α=0.1, true class = 2)

one_hot = [0,     0,     1,     0    ]
target  = [0.025, 0.025, 0.925, 0.025]   # sums to 1, peak at true class
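The worked example can be reproduced with `jax.nn.one_hot`. As an extra sanity check (our addition, not part of the original example), all-zero logits give log_softmax = -log K in every slot, so the loss equals log K whatever α is, because the target sums to 1.

```python
import jax
import jax.numpy as jnp

K, alpha = 4, 0.1
labels = jnp.array([2])                       # one example, true class = 2
one_hot = jax.nn.one_hot(labels, K)           # [[0., 0., 1., 0.]]
target = (1 - alpha) * one_hot + alpha / K    # [[0.025, 0.025, 0.925, 0.025]]

# All-zero logits: log p = -log K in every slot, so the loss is log 4 ≈ 1.386.
logits = jnp.zeros((1, K))
loss = -jnp.mean(jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1))
```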

Common pitfalls

  • Forgetting .astype(jnp.int32) on the labels before one_hot: jax.nn.one_hot expects integer class indices, and here the labels arrive as floats.
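  • Confusing softmax with log_softmax: the loss is -sum(target * log_p), so you need log-probabilities. Use jax.nn.log_softmax rather than jnp.log(jax.nn.softmax(...)); the latter underflows to log(0) = -inf for very confident logits.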
  • Using optax.softmax_cross_entropy_with_integer_labels: that is the non-smoothed variant, and it takes integer labels, not a soft target. Use optax.softmax_cross_entropy(logits, labels=target), or compute it by hand as in the formula above.
  • Reducing on the wrong axis: per-example loss sums over classes (axis=-1); the batch mean averages over examples (axis=0).
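A small sketch of the pitfalls above. The logits are deliberately extreme (an illustrative choice to trigger float32 underflow): exp(-1000) rounds to 0, so the naive jnp.log(jax.nn.softmax(...)) yields -inf in a slot where the smoothed target is nonzero, while jax.nn.log_softmax stays finite.

```python
import jax
import jax.numpy as jnp

alpha = 0.1
labels_float = jnp.array([1.0, 0.0])         # labels delivered as floats
labels = labels_float.astype(jnp.int32)      # pitfall 1: cast before one_hot
logits = jnp.array([[0.0, 1000.0],           # very confident and correct
                    [1000.0, 0.0]])
K = logits.shape[-1]
target = (1 - alpha) * jax.nn.one_hot(labels, K) + alpha / K   # rows [0.05, 0.95] / [0.95, 0.05]

# Pitfall 2: softmax gives an off-class probability of exactly 0.0 here,
# log(0) = -inf, and 0.05 * -inf makes each per-example loss inf.
naive = -jnp.sum(target * jnp.log(jax.nn.softmax(logits, axis=-1)), axis=-1)
stable = -jnp.sum(target * jax.nn.log_softmax(logits, axis=-1), axis=-1)   # 0.05 * 1000 = 50 per row

# Pitfall 3: sum over classes (axis=-1) per example, then average over the batch.
loss = jnp.mean(stable)
```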

Problem

Compute the label-smoothed cross-entropy:

target = (1 - alpha) * one_hot(labels, K) + alpha / K
loss = mean(-sum(target * log_softmax(logits, axis=-1), axis=-1))

Return the scalar loss wrapped in a 1-D array of shape (1,).

Inputs:

  • logits: 2-D (N, K).
  • labels_int: 1-D (N,) — integer class labels (delivered as floats; cast inside).
  • num_classes: scalar (cast to int) — K.
  • alpha: scalar — smoothing strength in [0, 1].

Output: 1-D array of shape (1,), i.e. [mean_loss].
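A hedged sketch of a function matching the input/output contract above. The name `label_smoothed_ce_problem` and the argument order are assumptions, since the page shows no starter signature. `int(num_classes)` only works on a concrete value, so this version is meant to be called outside `jax.jit`; under jit, reading K from `logits.shape[-1]` avoids the issue.

```python
import jax
import jax.numpy as jnp


def label_smoothed_ce_problem(logits, labels_int, num_classes, alpha):
    """Hypothetical adapter for the Problem's inputs; returns shape (1,)."""
    k = int(num_classes)                                  # needs a concrete value (not under jit)
    labels = jnp.asarray(labels_int).astype(jnp.int32)    # labels arrive as floats
    target = (1.0 - alpha) * jax.nn.one_hot(labels, k) + alpha / k
    log_p = jax.nn.log_softmax(logits, axis=-1)           # stable log-probabilities
    loss = jnp.mean(-jnp.sum(target * log_p, axis=-1))    # sum over classes, mean over batch
    return loss.reshape((1,))                             # scalar wrapped as a (1,) array


out = label_smoothed_ce_problem(jnp.zeros((2, 3)), jnp.array([0.0, 2.0]), 3.0, 0.1)
```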
