
eigh: Symmetric Eigendecomposition

Why this matters

jnp.linalg.eigh computes the eigendecomposition of a symmetric (or Hermitian) matrix, returning real eigenvalues and orthonormal eigenvectors. Because it exploits symmetry, it is typically faster and more numerically reliable than the general jnp.linalg.eig, and it always returns real eigenvalues (no spurious complex parts caused by floating-point asymmetry).
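A quick way to see the difference in return types is to run both routines on the same small symmetric matrix. This is a minimal sketch with an arbitrary example matrix; note that, depending on the JAX version, jnp.linalg.eig may only be supported on the CPU backend.

```python
import jax.numpy as jnp

# Arbitrary small symmetric matrix; its eigenvalues are 1 and 3.
A = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])

w_sym, v_sym = jnp.linalg.eigh(A)  # real eigenvalues, sorted ascending
w_gen, v_gen = jnp.linalg.eig(A)   # always complex dtype in JAX, no ordering guarantee

print(w_sym.dtype)  # float32 (with JAX's default 32-bit mode)
print(w_gen.dtype)  # complex64
# The general routine still finds the same values, just as complex numbers
# with (near-)zero imaginary parts and in no particular order.
print(jnp.allclose(jnp.sort(w_gen.real), w_sym, atol=1e-5))  # True
```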

Eigendecompositions of symmetric matrices underpin:

  • PCA: the covariance matrix is symmetric positive semidefinite (PSD); its eigenvectors are the principal components and its eigenvalues are the explained variances.
  • Spectral clustering: the graph Laplacian is symmetric; cluster structure lives in the eigenvectors belonging to the smallest eigenvalues.
  • Normal-mode analysis: the Hessian of a potential is symmetric; its smallest eigenvalues correspond to the softest vibrational modes.
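The PCA case can be made concrete with a short sketch. It assumes data laid out as (n_samples, n_features); the function name pca and the synthetic data are illustrative only, not a library API. Because eigh sorts ascending, the columns are reversed to put the largest-variance components first.

```python
import jax
import jax.numpy as jnp

def pca(X, k):
    """Hypothetical helper: project X (n_samples, n_features) onto its top-k principal components."""
    Xc = X - X.mean(axis=0)                  # center each feature
    cov = Xc.T @ Xc / (X.shape[0] - 1)       # sample covariance: symmetric PSD
    eigvals, eigvecs = jnp.linalg.eigh(cov)  # ascending order
    top_vals = eigvals[::-1][:k]             # explained variances, largest first
    top_vecs = eigvecs[:, ::-1][:, :k]       # matching principal components (columns)
    return Xc @ top_vecs, top_vals

# Synthetic data with very different spread along each axis.
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (200, 3)) * jnp.array([3.0, 1.0, 0.1])
Z, var = pca(X, 2)
print(Z.shape)  # (200, 2)
```

The variance of each projected column equals the corresponding eigenvalue, which is exactly the "explained variance" reading of the eigenvalues.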

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[3.0, 1.0],
               [1.0, 3.0]])

eigvals, eigvecs = jnp.linalg.eigh(A)
# eigvals ≈ [2.0, 4.0], sorted ASCENDING
# eigvecs: column i is the eigenvector for eigvals[i] (its sign is arbitrary)

eigvals[-1]
# → 4.0  (largest eigenvalue)
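To confirm that the returned pairs really are eigenpairs, one can check A v = λ v for every column at once, plus orthonormality and reconstruction. A minimal sketch on the same matrix; the tolerances are an arbitrary choice for float32.

```python
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0],
               [1.0, 3.0]])
eigvals, eigvecs = jnp.linalg.eigh(A)

# eigvecs * eigvals scales column i by eigvals[i], i.e. V @ diag(w).
residual = A @ eigvecs - eigvecs * eigvals
print(jnp.allclose(residual, 0.0, atol=1e-5))  # True

# Columns are orthonormal (V^T V = I), and A = V diag(w) V^T.
print(jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1e-5))             # True
print(jnp.allclose(eigvecs @ jnp.diag(eigvals) @ eigvecs.T, A, atol=1e-5))  # True
```

Checks like these are sign-agnostic, which matters because eigenvectors are only defined up to sign.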

Common pitfalls

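  • Output is sorted ASCENDING: the largest eigenvalue is eigvals[-1].
  • eigh vs eig: use eigh for symmetric/Hermitian matrices only; for general matrices use jnp.linalg.eig, which returns complex-valued results.
  • Asymmetric input is silently accepted: by default (symmetrize_input=True) JAX replaces A with its symmetric part (A + A^H)/2 before decomposing, so an asymmetric matrix yields the eigenvalues of that symmetric part rather than of A, with no error. With symmetrize_input=False, only one triangle is read (the lower one by default, selectable via UPLO).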
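The ordering and asymmetric-input pitfalls can be seen together in a small sketch, assuming the default jnp.linalg.eigh arguments. The matrix B is a made-up example chosen so that its true eigenvalues (3 and 3, since it is triangular) differ from those of its symmetric part.

```python
import jax.numpy as jnp

# B is NOT symmetric. It is upper triangular, so its true eigenvalues are 3 and 3.
B = jnp.array([[3.0, 2.0],
               [0.0, 3.0]])

w, _ = jnp.linalg.eigh(B)        # default symmetrize_input=True, no error raised
sym_part = (B + B.T) / 2         # [[3, 1], [1, 3]]
w_sym, _ = jnp.linalg.eigh(sym_part)

print(w)                                  # ≈ [2, 4]: eigenvalues of the symmetric part, not of B
print(jnp.allclose(w, w_sym, atol=1e-6))  # True
print(w[-1])                              # ≈ 4: the largest value sits at the END
```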

Problem

Implement eigh_largest(A) that returns the largest eigenvalue of a symmetric matrix.

  • A: 2-D jax array (n, n), symmetric.
  • Returns: scalar, the largest eigenvalue.
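One possible sketch of eigh_largest, relying on the ascending ordering described under Common pitfalls. It assumes a real floating-point symmetric input; wrapping it in jax.jit is an optional choice, not something the problem requires.

```python
import jax
import jax.numpy as jnp

@jax.jit
def eigh_largest(A):
    """Largest eigenvalue of a symmetric matrix A with shape (n, n)."""
    eigvals, _ = jnp.linalg.eigh(A)  # real eigenvalues, sorted ascending
    return eigvals[-1]               # last entry is the maximum

print(eigh_largest(jnp.array([[3.0, 1.0],
                              [1.0, 3.0]])))  # ≈ 4.0
```

jnp.linalg.eigvalsh(A)[-1] is an equivalent alternative that returns only the eigenvalues.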

Hints

jax linalg eigh
