Accuracy – Multiclass

Implement multiclass classification accuracy.

What is accuracy?

Accuracy is the fraction of examples whose predicted class matches the true class label. For a multiclass model that outputs a score vector of shape (C,) per example, the predicted class is the index with the highest score, i.e. the argmax.

accuracy = (number of correct predictions) / (total examples)
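The same definition can be written with an indicator function. The symbols below are notation introduced here for illustration: s_{i,c} is the score of example i for class c, \hat{y}_i its predicted class, and y_i its true label, with classes indexed from 0.

```latex
\hat{y}_i = \arg\max_{c \in \{0, \dots, C-1\}} s_{i,c},
\qquad
\text{accuracy} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{1}\left[\hat{y}_i = y_i\right]
```

The sum counts correct predictions and the 1/N factor divides by the total number of examples, matching the ratio above.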

Step by step

  1. Take argmax over the class dimension (dim=-1 in PyTorch / axis=-1 in JAX) to get a predicted class index for each example.
  2. Compare element-wise to labels; this yields a boolean tensor.
  3. Cast to float and take the mean. Return as a Python scalar.
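The three steps above can be sketched without any framework by treating the (N, C) scores as a list of rows. This is a dependency-free Python sketch, not the reference solution: the name multiclass_accuracy is hypothetical, and rounding float labels to the nearest integer is an assumption about how the test harness encodes them. In PyTorch the same pipeline is roughly (predictions.argmax(dim=-1) == labels).float().mean().item().

```python
from typing import Sequence


def multiclass_accuracy(predictions: Sequence[Sequence[float]],
                        labels: Sequence[float]) -> float:
    """Fraction of rows whose argmax equals the label (hypothetical helper)."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must both have length N")
    correct = 0
    for scores, label in zip(predictions, labels):
        # Step 1: argmax over the class dimension. max() returns the first
        # maximal item, so ties resolve to the lowest class index.
        predicted = max(range(len(scores)), key=lambda c: scores[c])
        # Step 2: element-wise comparison. Assumption: labels may arrive as
        # floats such as 2.0, so round them to an integer class index.
        if predicted == int(round(label)):
            correct += 1
    # Step 3: mean of the matches, returned as a plain Python float.
    return correct / len(predictions)


scores = [[0.1, 2.0, -1.0],  # argmax 1
          [3.0, 0.5, 0.2],   # argmax 0
          [0.0, 0.1, 0.9],   # argmax 2
          [1.0, 1.5, 0.3]]   # argmax 1
labels = [1.0, 0.0, 1.0, 1.0]
print(multiclass_accuracy(scores, labels))  # 0.75
```

Three of the four argmax predictions match their labels, giving 3/4 = 0.75.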

When to use accuracy vs F1

Accuracy is the right default for balanced class distributions: if each class appears roughly equally often, a random classifier scores about 1/C, so accuracy well above that baseline reflects real skill. When classes are imbalanced (e.g. 99% negatives), accuracy is misleading: a model that always predicts the majority class is 99% accurate but useless. In that case prefer macro-F1 or per-class precision/recall.
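To make the imbalance argument concrete, here is a small dependency-free sketch with 99 negatives and 1 positive, scored against a model that always predicts class 0. The helper names are hypothetical, and treating undefined precision (a class that is never predicted) as 0 is a common convention assumed here, not something this problem specifies.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_for_class(y_true, y_pred, cls):
    """One-vs-rest F1 for a single class."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    # Assumption: undefined precision/recall (zero denominator) counts as 0.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores."""
    return sum(f1_for_class(y_true, y_pred, c)
               for c in range(num_classes)) / num_classes


y_true = [0] * 99 + [1]  # 99% negatives
y_pred = [0] * 100       # always predict the majority class
print(accuracy(y_true, y_pred))     # 0.99
print(macro_f1(y_true, y_pred, 2))  # roughly 0.497
```

Class 0 gets an F1 near 0.995 and class 1 gets 0, so averaging per class halves the score and exposes the useless minority-class behavior that accuracy hides.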

Reference: sklearn.metrics.accuracy_score computes the same quantity, but it takes predicted labels rather than scores, so pass it the argmax.

Inputs

  • predictions: shape (N, C), class scores (logits or probabilities). N examples, C classes.
  • labels: shape (N,), integer class indices, delivered as floats by the test harness.

Output

Scalar float โ€” fraction of examples whose argmax equals the label.
