
Train Multiclass Classifier with Early Stopping

Train a multiclass linear classifier (softmax + cross-entropy) with early stopping based on validation loss.

The model

Given a feature matrix x_train of shape (N, d) and integer class labels y_train of shape (N,) with C classes, the model computes class probabilities via softmax:

$$p_{ic} = \text{softmax}(x_i^\top W)_c = \frac{e^{x_i^\top w_c}}{\sum_{c'=1}^{C} e^{x_i^\top w_{c'}}}$$

where W has shape (d, C), w_c is its c-th column, and x_i is the i-th row of x_train.
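A minimal NumPy sketch of the forward pass, assuming the rows of x are the examples x_i. The function name `predict_proba` and the toy numbers are illustrative, not part of the problem's interface. Each row of the result is a probability distribution over the C classes, with P[i, c] = softmax(x_i^T W)_c.

```python
import numpy as np

def predict_proba(x, w):
    """Return the (N, C) matrix P with P[i, c] = softmax(x_i^T W)_c."""
    logits = x @ w                                        # (N, C) scores x_i^T w_c
    logits = logits - logits.max(axis=1, keepdims=True)   # row-wise shift; softmax is unchanged
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)           # normalize each row

x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])        # N=3 examples, d=2 features
w = np.array([[2.0, 0.0, -1.0], [0.0, 1.0, 1.0]])         # W has shape (d, C) = (2, 3)
p = predict_proba(x, w)
print(p.shape)          # (3, 3)
print(p.sum(axis=1))    # each row sums to 1 (up to rounding)
```

The predicted class for each example is the column with the largest probability, e.g. `p.argmax(axis=1)`.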

Gradient of mean cross-entropy

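Define the mean cross-entropy loss $\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log p_{i,y_i}$, where $p_{i,y_i}$ is the predicted probability of example i's true class. Let OneHot(y) be the (N, C) matrix with a 1 in row i, column y_i, and 0 elsewhere. With X = x_train, the gradient with respect to W is: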

$$\nabla_W \mathcal{L} = \frac{1}{N} X^\top (P - \text{OneHot}(y))$$

where P is the (N, C) softmax probability matrix.
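One way to confirm the gradient formula is a finite-difference check. The sketch below assumes NumPy and uses illustrative helper names (`loss`, `grad`). It compares the analytic gradient (1/N) X^T (P - OneHot(y)) against central differences of the mean cross-entropy on small random data.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)                  # stable row-wise softmax
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def loss(w, x, y):
    """Mean cross-entropy: -(1/N) * sum_i log p_{i, y_i}."""
    p = softmax(x @ w)
    return -np.mean(np.log(p[np.arange(len(y)), y]))

def grad(w, x, y):
    """Analytic gradient (1/N) X^T (P - OneHot(y))."""
    n, c = x.shape[0], w.shape[1]
    one_hot = np.zeros((n, c))
    one_hot[np.arange(n), y] = 1.0                        # OneHot(y)
    return x.T @ (softmax(x @ w) - one_hot) / n

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))                               # N=5, d=3
y = np.array([0, 2, 1, 2, 0])                             # C=3 classes
w = rng.normal(size=(3, 3))

# Central finite differences, perturbing one weight at a time.
eps = 1e-6
num = np.zeros_like(w)
for j in range(w.shape[0]):
    for k in range(w.shape[1]):
        d = np.zeros_like(w)
        d[j, k] = eps
        num[j, k] = (loss(w + d, x, y) - loss(w - d, x, y)) / (2 * eps)

print(np.max(np.abs(num - grad(w, x, y))))                # should be close to zero
```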

Early stopping

Track the best validation loss seen so far. After each weight update, compute the validation cross-entropy. If it is strictly lower than the best so far, record it, save a copy of the current weights as best_w, and reset the no-improvement counter to 0. Otherwise, increment the counter; once it reaches patience, stop training. Whether training stops early or runs all max_epochs epochs, return best_w.
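The stopping rule can be checked on its own, independent of the model. This hypothetical helper replays the rule on a list of per-epoch validation losses (made-up numbers, not results) and reports which epoch's weights would be kept as best_w and how many epochs actually ran.

```python
def early_stop_epoch(val_losses, patience):
    """Replay the early-stopping rule on per-epoch validation losses.

    Returns (best_epoch, epochs_run). best_epoch is None when no epoch ran,
    in which case the initial weights w0 would be returned.
    """
    best_loss, best_epoch, no_improve = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:                  # strict improvement resets the counter
            best_loss, best_epoch, no_improve = loss, epoch, 0
        else:
            no_improve += 1
            if no_improve >= patience:        # patience exhausted: stop after this epoch
                return best_epoch, epoch + 1
    return best_epoch, len(val_losses)        # ran out of epochs without stopping

print(early_stop_epoch([1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.5], patience=3))  # (2, 6)
```

The later loss of 0.5 is never reached: three consecutive non-improving epochs after epoch 2 end training first, and the weights from epoch 2 are kept.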

Algorithm

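import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)           # stable: subtract row-wise max first
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def train(x_train, y_train, x_val, y_val, w0, lr, max_epochs, patience):  # name is illustrative
    y_tr, y_va = y_train.astype(int), y_val.astype(int)   # labels arrive as floats
    n_train = x_train.shape[0]
    one_hot = np.zeros((n_train, w0.shape[1]))
    one_hot[np.arange(n_train), y_tr] = 1.0        # built once; labels never change
    w = w0.copy()
    best_loss = np.inf; best_w = w0.copy(); no_improve = 0
    for epoch in range(max_epochs):
        # Train step
        logits = x_train @ w                       # (N_train, C)
        probs = softmax(logits)
        grad = x_train.T @ (probs - one_hot) / n_train   # (d, C)
        w = w - lr * grad

        # Val step
        val_probs = softmax(x_val @ w)
        val_loss = -np.mean(np.log(val_probs[np.arange(len(y_va)), y_va]))

        if val_loss < best_loss:
            best_loss = val_loss; best_w = w.copy(); no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break

    return best_w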

Note: use a numerically stable softmax by subtracting the row-wise maximum from the logits before exponentiating. The shift leaves the probabilities unchanged but keeps exp from overflowing.
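To see why the shift matters, the sketch below (NumPy assumed; function names are illustrative) applies a naive and a shifted softmax to large logits. The naive version overflows because exp(1000) is inf in float64, and inf / inf gives nan.

```python
import numpy as np

def naive_softmax(z):
    e = np.exp(z)                                    # overflows to inf for large logits
    return e / e.sum(axis=1, keepdims=True)

def stable_softmax(z):
    z = z - z.max(axis=1, keepdims=True)             # largest entry in each row becomes 0
    e = np.exp(z)                                    # all values now in (0, 1]
    return e / e.sum(axis=1, keepdims=True)

z = np.array([[1000.0, 1001.0, 1002.0]])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(z))                          # [[nan nan nan]]
print(stable_softmax(z))                             # finite; equals softmax of [0, 1, 2]
```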

Inputs:

  • x_train: shape (N_train, d), the training features.
  • y_train: shape (N_train,), integer class labels in [0, C), delivered as floats.
  • x_val: shape (N_val, d), the validation features.
  • y_val: shape (N_val,), integer class labels in [0, C), delivered as floats.
  • w0: shape (d, C), the initial weights (deterministic).
  • lr: float, the learning rate.
  • max_epochs: int, the upper bound on the number of epochs.
  • patience: int; stop if the validation loss does not improve for this many consecutive epochs.
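Because the labels arrive as floats, they must be cast to integers before they can index a NumPy array (float index arrays raise an IndexError). A minimal sketch with a hypothetical helper, `labels_to_indices`, that casts and sanity-checks the labels, then builds the one-hot matrix used in the gradient:

```python
import numpy as np

def labels_to_indices(y, num_classes):
    """Convert float-encoded labels such as 2.0 into int class indices in [0, C)."""
    idx = y.astype(np.int64)
    if not np.array_equal(idx, y):                   # reject non-integral values like 1.5
        raise ValueError("labels must be whole numbers")
    if idx.size and (idx.min() < 0 or idx.max() >= num_classes):
        raise ValueError("labels must lie in [0, C)")
    return idx

y_train = np.array([0.0, 2.0, 1.0, 2.0])
idx = labels_to_indices(y_train, num_classes=3)
one_hot = np.zeros((len(idx), 3))
one_hot[np.arange(len(idx)), idx] = 1.0              # row i has a 1 in column y_i
print(one_hot)
```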

Output: best_w of shape (d, C) β€” the weights at the epoch with the lowest validation loss.

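Edge cases: lr=0 → the weights never change, so the validation loss improves only on the first epoch (any finite loss beats inf) → patience triggers and the returned best_w equals w0; max_epochs=0 → the loop never runs → return w0.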
