
Cholesky Decomposition

Why this matters

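The Cholesky decomposition factors a symmetric positive definite (SPD) matrix A into L @ L.T, where L is lower triangular with a positive diagonal. It needs about half the floating-point operations of LU decomposition (roughly n³/3 versus 2n³/3), so it is about 2× faster on SPD matrices. It is also numerically stable without pivoting, because every pivot is guaranteed to be positive. A singular positive semi-definite matrix still has a (non-unique) Cholesky factor in exact arithmetic, but standard routines treat a zero pivot as a failure.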
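To see how L is built, here is a minimal sketch of the textbook row-by-row (Cholesky-Banachiewicz) algorithm. The function name cholesky_banachiewicz is illustrative and is not part of JAX. The sketch assumes an SPD input and skips all error checks. Each diagonal entry is the square root of a pivot, which is why the method breaks down when A is not positive definite.

```python
import jax.numpy as jnp


def cholesky_banachiewicz(A):
    """Illustrative textbook Cholesky: returns lower-triangular L with L @ L.T == A.

    Assumes A is symmetric positive definite; no pivoting, no error checks.
    """
    n = A.shape[0]
    L = jnp.zeros_like(A)
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j of L
            # (an empty dot product, when j == 0, is 0).
            s = jnp.dot(L[i, :j], L[j, :j])
            if i == j:
                # Diagonal entry: square root of the pivot, positive when A is SPD.
                L = L.at[i, i].set(jnp.sqrt(A[i, i] - s))
            else:
                # Below-diagonal entry: divide by the diagonal entry of row j.
                L = L.at[i, j].set((A[i, j] - s) / L[j, j])
    return L


A = jnp.array([[4.0, 12.0, -16.0],
               [12.0, 37.0, -43.0],
               [-16.0, -43.0, 98.0]])
L = cholesky_banachiewicz(A)
# L → [[ 2., 0., 0.],
#      [ 6., 1., 0.],
#      [-8., 5., 3.]]
print(jnp.allclose(L, jnp.linalg.cholesky(A), atol=1e-5))  # True
```

The Python loops keep the logic readable, but they make this far slower than jnp.linalg.cholesky, which dispatches to an optimized backend kernel.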

Cholesky is ubiquitous in scientific computing:

  • Gaussian Processes: factoring the kernel (covariance) matrix.
  • Bayesian inference: posterior covariance in Laplace approximations.
  • Sampling multivariate normals: draw z ~ N(0, I), then L @ z ~ N(0, A).
  • Solving linear systems: solve A x = b via two triangular solves.
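The last two uses in the list take only a few lines each. This sketch assumes jax.random for the normal draws and jax.scipy.linalg.solve_triangular for the two triangular solves. The PRNG seed, sample count, and right-hand side b are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[4.0, 2.0],
               [2.0, 5.0]])
L = jnp.linalg.cholesky(A)

# Sampling: if z ~ N(0, I), then L @ z has covariance L @ L.T = A.
key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (2, 100_000))           # each column is one draw
samples = L @ z                                    # shape (2, 100_000)
empirical_cov = samples @ samples.T / z.shape[1]   # close to A (the mean is zero)

# Solving A x = b with two triangular solves, since A = L @ L.T:
#   L y = b      (forward substitution)
#   L.T x = y    (back substitution)
b = jnp.array([1.0, 2.0])
y = solve_triangular(L, b, lower=True)
x = solve_triangular(L.T, y, lower=False)
# x ≈ [0.0625, 0.375], i.e. A^{-1} b
```

In practice the factor L is computed once and reused for many right-hand sides or many batches of samples, which is where the savings over a general solver come from.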

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[4.0, 2.0],
               [2.0, 5.0]])

L = jnp.linalg.cholesky(A)
# L = [[2.0, 0.0],
#      [1.0, 2.0]]

# Verify: L @ L.T == A
L @ L.T
# → [[4.0, 2.0],
#    [2.0, 5.0]]

Common pitfalls

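  • Non-positive-definite input: jnp.linalg.cholesky does not raise an error; it silently returns NaNs. If unsure, check that the smallest eigenvalue from jnp.linalg.eigvalsh(A) is positive, or test the result with jnp.isnan.
  • Output is lower triangular: every entry above the diagonal is zero.
  • Symmetry requirement: the factorization assumes A is symmetric and does not verify it. Depending on the routine, an asymmetric input is either symmetrized or only partly read, so L @ L.T will not reproduce A. Make sure A is truly symmetric before calling.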
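The pitfalls above can be checked in code. In this sketch the indefinite matrix B is an arbitrary example, and checked_cholesky is a hypothetical helper, not a JAX function. Because the NaN check needs a concrete value to branch on, the helper works only outside jax.jit.

```python
import jax.numpy as jnp

# Symmetric but indefinite: its eigenvalues are -1 and 3.
B = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])

eigs = jnp.linalg.eigvalsh(B)       # ascending order; eigs[0] < 0, so B is not PD
L_bad = jnp.linalg.cholesky(B)
print(jnp.isnan(L_bad).any())       # NaNs signal failure; no exception is raised


def checked_cholesky(A):
    """Hypothetical helper: symmetrize A, factor it, and raise if the factorization failed."""
    A_sym = (A + A.T) / 2            # guard against small asymmetries
    L = jnp.linalg.cholesky(A_sym)
    if not jnp.all(jnp.isfinite(L)):  # concrete check, so not usable under jit
        raise ValueError("matrix is not (numerically) positive definite")
    return L
```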

Problem

Implement cholesky_factor(A) that returns the lower-triangular Cholesky factor L such that L @ L.T = A.

  • A: 2-D jax array (n, n), symmetric positive definite.
  • Returns: 2-D array (n, n), the lower-triangular factor L.
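Here is a minimal sketch of how a solution could be tested against this contract. check_cholesky_factor is a hypothetical helper. The random test matrix M @ M.T + n·I is an assumption, chosen because it is always positive definite. The positive-diagonal check encodes the usual convention that makes L unique.

```python
import jax
import jax.numpy as jnp


def check_cholesky_factor(fn, n=4, seed=0, atol=1e-4):
    """Hypothetical test helper: check that fn(A) meets the contract on a random SPD matrix."""
    M = jax.random.normal(jax.random.PRNGKey(seed), (n, n))
    # M @ M.T is PSD; adding n * I pushes every eigenvalue to at least n, so A is SPD.
    A = M @ M.T + n * jnp.eye(n)
    L = fn(A)
    assert L.shape == (n, n), "output must have shape (n, n)"
    assert jnp.all(jnp.triu(L, k=1) == 0), "entries above the diagonal must be zero"
    assert jnp.all(jnp.diag(L) > 0), "diagonal must be positive (uniqueness convention)"
    assert jnp.allclose(L @ L.T, A, atol=atol), "L @ L.T must reconstruct A"
    return True
```

For example, check_cholesky_factor(jnp.linalg.cholesky) returns True, while a function that returns A unchanged fails the upper-triangle check.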

Hints

Look at jnp.linalg.cholesky.
