Train Multiclass Classifier with Early Stopping
Train a multiclass linear classifier (softmax + cross-entropy) with early stopping based on validation loss.
The model
Given a feature matrix x_train of shape (N, d) and integer class labels
y_train of shape (N,) with C classes, the model computes class
probabilities via softmax:
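$$p_{i,c} = \text{softmax}(x_i^\top W)_c = \frac{e^{x_i^\top w_c}}{\sum_{c'} e^{x_i^\top w_{c'}}}$$
where W has shape (d, C), w_c is its c-th column, and p_{i,c} is the predicted probability that sample i belongs to class c.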
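The softmax above can be sketched in NumPy as follows. This is a minimal illustration, not a reference solution: the function name `softmax` and the toy arrays `x` and `W` are assumptions chosen for the example. It subtracts each row's maximum before exponentiating, which leaves the probabilities unchanged but avoids overflow.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax of an (N, C) array of logits."""
    # Shifting by the row maximum does not change the result,
    # but keeps np.exp from overflowing on large logits.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Toy example (hypothetical data): N = 2 samples, d = 3 features, C = 2 classes.
x = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
W = np.zeros((3, 2))      # all-zero weights give uniform probabilities
P = softmax(x @ W)        # shape (2, 2); every entry is 0.5
```

Each row of `P` sums to 1, and very large logits such as 1000 still produce finite probabilities because of the shift.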
Gradient of mean cross-entropy
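Build a one-hot matrix from y_train. The gradient of the mean cross-entropy with respect to W is then:
$$\nabla_W \mathcal{L} = \frac{1}{N} X^\top (P - \text{OneHot}(y))$$
where X is x_train, P is the (N, C) softmax probability matrix, and OneHot(y) is the (N, C) one-hot encoding of y_train.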
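As a rough sketch of this formula, assuming NumPy and a hypothetical helper name `cross_entropy_grad`, the one-hot matrix can be filled with integer array indexing and the gradient computed in one matrix product. The labels are cast to integers first, since they may arrive as floats.

```python
import numpy as np

def cross_entropy_grad(X: np.ndarray, y: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Gradient of the mean cross-entropy w.r.t. W, shape (d, C)."""
    N, C = X.shape[0], W.shape[1]
    # Numerically stable row-wise softmax of the logits.
    logits = X @ W
    logits = logits - logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)
    # One-hot encode the labels (cast to int in case they are floats).
    one_hot = np.zeros((N, C))
    one_hot[np.arange(N), y.astype(int)] = 1.0
    return X.T @ (P - one_hot) / N

# Toy example (hypothetical data): identity features, zero weights.
X = np.eye(2)
y = np.array([0.0, 1.0])
grad = cross_entropy_grad(X, y, np.zeros((2, 2)))
# grad is [[-0.25, 0.25], [0.25, -0.25]]
```

With zero weights every probability is 0.5, so each entry of `P - one_hot` is ±0.5 and dividing by N = 2 gives ±0.25. Each row of the gradient sums to zero because the rows of both `P` and the one-hot matrix sum to 1.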
Early stopping
Track the best validation loss seen so far. After each weight update, compute
the validation cross-entropy. If it improves, reset the no-improvement counter
and save the current weights as best_w. Otherwise, increment the counter;
when it reaches patience, stop training and return best_w.
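The patience rule can be illustrated in isolation on a plain list of validation losses. The helper `early_stop_epoch` below is hypothetical and exists only for this example; it reports which epoch held the best loss and at which epoch training would stop.

```python
import math

def early_stop_epoch(val_losses, patience):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses."""
    best_loss = math.inf
    best_epoch = -1                      # -1: no epoch has improved yet
    no_improve = 0
    stop_epoch = len(val_losses) - 1     # default: ran through every epoch
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # Strict improvement: remember it and reset the counter.
            best_loss, best_epoch, no_improve = loss, epoch, 0
        else:
            no_improve += 1
            if no_improve >= patience:
                stop_epoch = epoch
                break
    return best_epoch, stop_epoch

best, stopped = early_stop_epoch([0.9, 0.7, 0.8, 0.75, 0.6], patience=2)
# best == 1, stopped == 3
```

Here the loss 0.7 at epoch 1 is the best seen; epochs 2 and 3 fail to beat it, so with patience 2 training stops at epoch 3 and the later value 0.6 is never evaluated. That is the trade-off a small patience makes.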
Algorithm
best_loss = inf; best_w = w0; no_improve = 0
for epoch in range(max_epochs):
    # Train step
    logits = x_train @ w                      # (N_train, C)
    probs = softmax(logits)                   # stable: subtract row-wise max first
    one_hot[i, y_train[i]] = 1
    grad = x_train.T @ (probs - one_hot) / N_train
    w = w - lr * grad
    # Val step
    val_probs = softmax(x_val @ w)
    val_loss = -mean(log(val_probs[i, y_val[i]]))
    if val_loss < best_loss:
        best_loss = val_loss; best_w = w.copy(); no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience: break
return best_w
Note: Use a numerically stable softmax: subtract the row-wise maximum before exponentiating.
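One way the algorithm might be written in NumPy is sketched below. The function name `train_with_early_stopping` and the toy data are assumptions for illustration, not the site's reference solution. Labels are cast to integers because they may be delivered as floats.

```python
import numpy as np

def softmax(logits):
    # Stable row-wise softmax: subtract the row maximum first.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def train_with_early_stopping(x_train, y_train, x_val, y_val,
                              w0, lr, max_epochs, patience):
    y_tr = np.asarray(y_train).astype(int)
    y_va = np.asarray(y_val).astype(int)
    n_train, n_classes = x_train.shape[0], w0.shape[1]
    # The one-hot matrix does not change between epochs, so build it once.
    one_hot = np.zeros((n_train, n_classes))
    one_hot[np.arange(n_train), y_tr] = 1.0

    w = np.array(w0, dtype=float)        # copy, so w0 is not mutated
    best_w, best_loss, no_improve = w.copy(), np.inf, 0
    for _ in range(max_epochs):
        # Train step: one full-batch gradient descent update.
        probs = softmax(x_train @ w)
        w = w - lr * (x_train.T @ (probs - one_hot)) / n_train
        # Val step: mean cross-entropy of the true classes.
        val_probs = softmax(x_val @ w)
        val_loss = -np.mean(np.log(val_probs[np.arange(len(y_va)), y_va]))
        if val_loss < best_loss:
            best_loss, best_w, no_improve = val_loss, w.copy(), 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break
    return best_w

# Toy, linearly separable data (hypothetical): class 1 has positive first feature.
x_train = np.array([[1.0, 0.0], [2.0, 0.5], [-1.0, 0.0], [-2.0, -0.5]])
y_train = np.array([1.0, 1.0, 0.0, 0.0])    # labels as floats
x_val = np.array([[1.5, 0.2], [-1.5, -0.2]])
y_val = np.array([1.0, 0.0])
w0 = np.zeros((2, 2))
best_w = train_with_early_stopping(x_train, y_train, x_val, y_val,
                                   w0, lr=0.5, max_epochs=100, patience=3)
```

Calling the same function with `max_epochs=0` returns a copy of `w0` without entering the loop, and `lr=0.0` also returns `w0`, since the weights never move.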
Inputs:
- x_train: shape (N_train, d), training features.
- y_train: shape (N_train,), integer class labels in [0, C), delivered as floats.
- x_val: shape (N_val, d), validation features.
- y_val: shape (N_val,), integer class labels in [0, C), delivered as floats.
- w0: shape (d, C), initial weights (deterministic).
- lr: float, learning rate.
- max_epochs: int, upper bound on epochs.
- patience: int, stop if the validation loss doesn't improve for this many consecutive epochs.
Output: best_w of shape (d, C), the weights at the epoch with the lowest validation loss.
Edge cases: with lr=0 the weights never change, so the validation loss never improves after the first epoch and patience triggers; with max_epochs=0 the loop never runs, so w0 is returned.