medium primitives

Triplet Loss

Implement triplet loss, a metric-learning objective that trains embeddings so that similar examples cluster together and dissimilar examples are pushed apart.

Given a batch of triplets, each consisting of an anchor $a$, a positive example $p$ (same class as the anchor), and a negative example $n$ (different class), the loss is:

$$\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \max\!\left(0,\; \|a_i - p_i\|^2 - \|a_i - n_i\|^2 + \alpha\right)$$

where $N$ is the number of triplets in the batch and $\alpha$ is the margin hyperparameter. The squared L2 distances measure how far apart the embeddings are; the margin requires $\|a-n\|^2$ to exceed $\|a-p\|^2$ by at least $\alpha$.
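As a worked example of the formula (the numbers are illustrative and not taken from the problem), consider a batch of $N = 2$ triplets with $\alpha = 0.5$. Suppose the first triplet has $\|a_1 - p_1\|^2 = 1$ and $\|a_1 - n_1\|^2 = 9$, and the second has $\|a_2 - p_2\|^2 = 4$ and $\|a_2 - n_2\|^2 = 2$. Writing $\ell_i$ for the per-triplet term:

$$\begin{aligned}
\ell_1 &= \max(0,\; 1 - 9 + 0.5) = \max(0,\; -7.5) = 0 \\
\ell_2 &= \max(0,\; 4 - 2 + 0.5) = \max(0,\; 2.5) = 2.5 \\
\mathcal{L} &= \tfrac{1}{2}\,(0 + 2.5) = 1.25
\end{aligned}$$

Only the second triplet, whose negative is closer to the anchor than its positive, contributes to the batch loss.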

Intuition: The loss for a triplet is zero when the negative is already farther from the anchor than the positive by at least $\alpha$ (in squared distance). Otherwise, the penalty equals the amount by which this constraint is violated.
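The hinge behaviour described above can be seen by moving a negative away from a fixed anchor. The following is a minimal sketch using NumPy arrays in place of tensors; the helper name `per_triplet_loss` and the sample points are assumptions made for illustration.

```python
import numpy as np

def per_triplet_loss(a, p, n, margin):
    """Hinge term for a single triplet, using squared L2 distances."""
    d_ap = np.sum((a - p) ** 2)  # squared distance anchor -> positive
    d_an = np.sum((a - n) ** 2)  # squared distance anchor -> negative
    return max(0.0, float(d_ap - d_an + margin))

anchor = np.array([0.0, 0.0])
positive = np.array([1.0, 0.0])  # squared distance to the anchor is 1
margin = 1.0

# Slide the negative along the x-axis, away from the anchor.
losses = {}
for x in (0.5, 1.0, 1.5, 2.0):
    negative = np.array([x, 0.0])
    losses[x] = per_triplet_loss(anchor, positive, negative, margin)

print(losses)  # {0.5: 1.75, 1.0: 1.0, 1.5: 0.0, 2.0: 0.0}
```

With these values the loss reaches zero once the negative's squared distance is at least $1 + \alpha = 2$; pushing the negative further away changes nothing.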

History: Originally proposed by Weinberger & Saul (2006) as part of Large Margin Nearest Neighbour (LMNN) classification. Popularised for face recognition by Schroff et al. in FaceNet (2015). Now widely used in contrastive and metric learning: face verification, image retrieval, sentence embeddings, and speaker identification.

Inputs:

  • anchor: tensor of shape (N, d)
  • positive: tensor of shape (N, d), same class as anchor
  • negative: tensor of shape (N, d), different class from anchor
  • margin: Python float scalar

Output: Scalar, the mean loss over the batch.
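One possible way to meet this interface is sketched below. It is not the reference solution: it assumes NumPy arrays stand in for the tensors, and the function name `triplet_loss`, the shape check, and the sample batch are illustrative choices.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin):
    """Mean triplet loss over a batch of (N, d) embeddings."""
    anchor = np.asarray(anchor, dtype=np.float64)
    positive = np.asarray(positive, dtype=np.float64)
    negative = np.asarray(negative, dtype=np.float64)

    # All three inputs are expected to share the same (N, d) shape.
    if anchor.ndim != 2 or anchor.shape != positive.shape or anchor.shape != negative.shape:
        raise ValueError("anchor, positive and negative must all have shape (N, d)")

    # Squared L2 distances, one value per triplet.
    d_ap = np.sum((anchor - positive) ** 2, axis=1)
    d_an = np.sum((anchor - negative) ** 2, axis=1)

    # Hinge on the margin violation, then average over the batch.
    per_triplet = np.maximum(0.0, d_ap - d_an + margin)
    return float(per_triplet.mean())

# Illustrative batch of two triplets in 2-D.
anchor = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[1.0, 0.0], [2.0, 0.0]])
negative = np.array([[3.0, 0.0], [1.0, 1.0]])

loss = triplet_loss(anchor, positive, negative, margin=0.5)
print(loss)  # 1.25
```

Summing over `axis=1` keeps one distance per triplet, so the hinge is applied before averaging, as in the formula. If you compare against PyTorch, note that `torch.nn.TripletMarginLoss` uses the unsquared distance by default, so its values will differ from this squared-distance version.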


loss embedding metric-learning
