
SVD Singular Values

Why this matters

The Singular Value Decomposition (SVD) factors a matrix as A = U @ diag(s) @ V.T, where U and V have orthonormal columns and s holds the singular values (non-negative, sorted in descending order). The singular values reveal:

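  • Rank: the number of nonzero singular values equals rank(A). In floating point, count the values above a small tolerance rather than testing for exact zero.
  • Condition number: s[0] / s[-1] is the 2-norm condition number; a large ratio means A is ill-conditioned, and it is infinite when A is rank-deficient.
  • Low-rank approximation: keeping only the top-k singular values and their singular vectors gives the best rank-k approximation of A in the spectral and Frobenius norms (Eckart–Young theorem).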
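The three properties above can be checked numerically. This is a minimal sketch that assumes an illustrative 3x3 matrix whose third row is the sum of the first two (so its exact rank is 2). It also assumes a relative tolerance rtol = 1e-5 chosen by hand for float32, which is not a JAX default. It computes the numerical rank, the 2-norm condition number, and a rank-1 Eckart–Young approximation whose spectral-norm error should match s[1].

```python
import jax.numpy as jnp

# Illustrative matrix (assumption): row 3 = row 1 + row 2, so the exact rank is 2
# and the smallest singular value is zero up to rounding.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [5.0, 7.0, 9.0]])

# Thin SVD: U is (3, 3), s is (3,), Vh is (3, 3); s is sorted descending.
U, s, Vh = jnp.linalg.svd(A, full_matrices=False)

# Numerical rank: count singular values above a relative tolerance.
# rtol = 1e-5 is a hand-picked value for float32, not a library default.
rtol = 1e-5
rank = int(jnp.sum(s > rtol * s[0]))

# 2-norm condition number: largest over smallest singular value.
# For this rank-deficient A it is huge (or inf if s[-1] is exactly 0).
cond = s[0] / s[-1]

# Best rank-k approximation (Eckart-Young): keep the top-k singular triplets.
k = 1
A_k = U[:, :k] @ jnp.diag(s[:k]) @ Vh[:k, :]

# In the spectral norm, the approximation error equals s[k].
err = jnp.linalg.norm(A - A_k, ord=2)
print(rank, float(cond), float(err), float(s[k]))
```

The explicit tolerance matters because rounding rarely produces an exact 0.0 for the dropped singular value.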

Key applications:

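  • PCA: the right singular vectors of the centered data matrix X are the principal directions, and s**2 / (n - 1) gives the variance along each one (n = number of samples).
  • Latent semantic analysis: a truncated SVD compresses the document-term matrix.
  • Matrix completion / recommender systems: a low-rank structure lets missing entries be estimated from the observed ones.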
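The PCA connection can be sketched in a few lines. This example assumes synthetic Gaussian data (200 samples, 3 features, fixed PRNG seed) rather than any dataset from this page. It centers X, takes the thin SVD, reads the principal directions from the rows of Vh, and converts singular values into explained variances.

```python
import jax
import jax.numpy as jnp

# Synthetic data (assumption): 200 samples, 3 correlated features, fixed seed.
key = jax.random.PRNGKey(0)
mix = jnp.array([[3.0, 0.0, 0.0],
                 [1.0, 1.0, 0.0],
                 [0.0, 0.0, 0.1]])
X = jax.random.normal(key, (200, 3)) @ mix

# Center each feature (column); PCA is defined on centered data.
Xc = X - X.mean(axis=0)

# Thin SVD of the centered data: the rows of Vh are the principal directions.
U, s, Vh = jnp.linalg.svd(Xc, full_matrices=False)

# Variance explained by each component (sample variance, n - 1 denominator).
n = Xc.shape[0]
explained_var = s**2 / (n - 1)

# Project onto the top-2 principal components.
scores = Xc @ Vh[:2].T
print(explained_var)
```

Projecting onto the rows of Vh gives the same scores as U[:, :2] * s[:2]. Working from the SVD of Xc also avoids forming the covariance matrix explicitly, which would square the condition number.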

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[3.0, 0.0],
               [0.0, 4.0]])

s = jnp.linalg.svd(A, compute_uv=False)
# s = [4.0, 3.0], sorted descending

Common pitfalls

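  • Output is sorted in DESCENDING order: s[0] is the largest singular value.
  • compute_uv=False returns just the singular values, which is cheaper. With the default compute_uv=True it returns (U, s, Vh), where Vh = V.T (the conjugate transpose for complex input). Because full_matrices=True is also the default, U is then (m, m) and Vh is (n, n).
  • Count: for an (m, n) matrix, len(s) = min(m, n).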
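The shape rules in these pitfalls are easiest to see on a non-square matrix. This sketch assumes an illustrative (4, 2) input and compares three calls: compute_uv=False, the default call, and the thin call with full_matrices=False.

```python
import jax.numpy as jnp

# Illustrative non-square (4, 2) matrix (assumption) with full column rank.
A = jnp.arange(8.0).reshape(4, 2)

# compute_uv=False: only the singular values, length min(m, n) = 2, descending.
s = jnp.linalg.svd(A, compute_uv=False)

# Defaults (compute_uv=True, full_matrices=True): U is (4, 4), Vh is (2, 2).
U_full, s_full, Vh_full = jnp.linalg.svd(A)

# full_matrices=False gives the thin factors: U is (4, 2), Vh is (2, 2).
U, s_thin, Vh = jnp.linalg.svd(A, full_matrices=False)

# The thin factors reconstruct A; Vh already plays the role of V.T.
A_rec = U @ jnp.diag(s_thin) @ Vh
print(s.shape, U_full.shape, Vh_full.shape, U.shape, Vh.shape)
```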

Problem

Implement svd_singular_values(A) that returns the singular values of A as a 1-D array sorted in descending order.

  • A: 2-D JAX array of shape (m, n).
  • Returns: 1-D array of shape (min(m, n),) holding the singular values, sorted descending.
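One possible sketch of svd_singular_values follows. It assumes a 2-D input and relies on jnp.linalg.svd already returning singular values in descending order with length min(m, n). It is an illustration and not necessarily the site's reference solution.

```python
import jax.numpy as jnp

def svd_singular_values(A):
    """Return the singular values of the 2-D array A, sorted descending.

    Sketch: jnp.linalg.svd already sorts them in descending order and
    returns min(m, n) of them, so no extra sorting is done here.
    """
    # compute_uv=False skips U and Vh, so only the singular values are computed.
    return jnp.linalg.svd(A, compute_uv=False)

# Example: the diagonal matrix from the worked example.
print(svd_singular_values(jnp.array([[3.0, 0.0], [0.0, 4.0]])))
```

Skipping the singular vectors makes this the cheaper call when only s is needed.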

Hints

jax.numpy.linalg.svd
