Difficulty: medium · Category: primitives
Cholesky Decomposition
Why this matters
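The Cholesky decomposition factors a symmetric positive-definite (SPD) matrix A into L @ L.T, where L is lower triangular with a positive diagonal. It is roughly 2× faster than LU decomposition (about n³/3 versus 2n³/3 floating-point operations), and it is numerically stable without pivoting because, for SPD input, every pivot (the value under each square root) is guaranteed positive. A singular positive semi-definite matrix can hit a zero pivot, which is why the standard algorithm is stated for the definite case.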
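To make the positive-pivot claim concrete, here is a minimal, unoptimized sketch of the row-by-row (Cholesky–Banachiewicz) recurrence. The function name `cholesky_banachiewicz` is hypothetical, not a JAX API, and the Python loops with `.at[...].set(...)` are for illustration only: they are slow and not meant for `jax.jit`.

```python
import jax.numpy as jnp

def cholesky_banachiewicz(A):
    """Row-by-row Cholesky factor of a symmetric positive-definite A (illustrative only)."""
    A = jnp.asarray(A, dtype=jnp.float32)
    n = A.shape[0]
    L = jnp.zeros_like(A)
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j of L.
            s = jnp.dot(L[i, :j], L[j, :j])
            if i == j:
                # Diagonal pivot: strictly positive whenever A is positive definite.
                pivot = A[i, i] - s
                if pivot <= 0:
                    raise ValueError("matrix is not positive definite")
                L = L.at[i, i].set(jnp.sqrt(pivot))
            else:
                # Off-diagonal entry: divide by the (positive) diagonal of row j.
                L = L.at[i, j].set((A[i, j] - s) / L[j, j])
    return L
```

Applied to the 2×2 matrix in the worked example, it returns [[2, 0], [1, 2]], matching `jnp.linalg.cholesky`. Summed over all (i, j) pairs, the dot products total about n³/6 multiply-adds, which is where the n³/3 flop count comes from.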
Cholesky is ubiquitous in scientific computing:
- Gaussian Processes: factoring the kernel (covariance) matrix.
- Bayesian inference: posterior covariance in Laplace approximations.
- Sampling multivariate normals: draw z ~ N(0, I), then L @ z ~ N(0, A).
- Solving linear systems: solve A x = b via two triangular solves (L y = b, then L.T x = y).
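The last two uses above fit in a few lines. This is a minimal sketch, assuming `jax.scipy.linalg.solve_triangular` for the two triangular solves; the helper names `cholesky_solve` and `sample_mvn` are hypothetical. Samples are stored as rows, so applying L to each one is written as `z @ L.T`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def cholesky_solve(A, b):
    """Solve A x = b for symmetric positive-definite A (hypothetical helper)."""
    L = jnp.linalg.cholesky(A)
    # Forward substitution: L y = b.
    y = solve_triangular(L, b, lower=True)
    # Back substitution: L.T x = y.
    return solve_triangular(L.T, y, lower=False)

def sample_mvn(key, A, num_samples):
    """Draw num_samples rows from N(0, A) (hypothetical helper)."""
    L = jnp.linalg.cholesky(A)
    # Each row of z is an independent draw from N(0, I).
    z = jax.random.normal(key, (num_samples, A.shape[0]))
    # Row-wise (L @ z_i).T == z_i.T @ L.T, so this applies L to every sample.
    return z @ L.T
```

For A = [[4, 2], [2, 5]] and b = [1, 2], `cholesky_solve` gives approximately [0.0625, 0.375]. JAX also packages the solve step as `jax.scipy.linalg.cho_factor` and `jax.scipy.linalg.cho_solve`. Once L is computed it can be reused for many right-hand sides or sample batches, which is the main practical advantage over calling a general solver each time.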
Worked mini-example
```python
import jax.numpy as jnp

A = jnp.array([[4.0, 2.0],
               [2.0, 5.0]])
L = jnp.linalg.cholesky(A)
# L = [[2.0, 0.0],
#      [1.0, 2.0]]

# Verify: L @ L.T == A
print(L @ L.T)
# → [[4.0, 2.0],
#    [2.0, 5.0]]
```
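The factor in this example can be checked by hand by expanding L Lᵀ and matching it to A entry by entry. The quantity under the second square root is the pivot that positive definiteness keeps positive:

```latex
\begin{aligned}
\begin{pmatrix} 4 & 2 \\ 2 & 5 \end{pmatrix}
&= \begin{pmatrix} \ell_{11} & 0 \\ \ell_{21} & \ell_{22} \end{pmatrix}
   \begin{pmatrix} \ell_{11} & \ell_{21} \\ 0 & \ell_{22} \end{pmatrix}
 = \begin{pmatrix} \ell_{11}^2 & \ell_{11}\ell_{21} \\ \ell_{11}\ell_{21} & \ell_{21}^2 + \ell_{22}^2 \end{pmatrix}, \\
\ell_{11} &= \sqrt{4} = 2, \qquad
\ell_{21} = \frac{2}{\ell_{11}} = 1, \qquad
\ell_{22} = \sqrt{5 - \ell_{21}^2} = \sqrt{4} = 2.
\end{aligned}
```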
Common pitfalls
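- Input that is not positive definite: jnp.linalg.cholesky does not raise an error; it silently returns NaN entries. If in doubt, check that every eigenvalue from jnp.linalg.eigvalsh(A) is strictly positive.
- Output is lower triangular: every entry above the diagonal is zero.
- Symmetry requirement: the input is not checked for symmetry (by default JAX factors the symmetric part (A + A.T) / 2), so ensure A is truly symmetric before calling.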
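One way to turn the silent-NaN failure into an explicit error is to validate the input before and after factoring. This is a minimal sketch; `checked_cholesky` and its `atol` tolerance are hypothetical choices, not part of JAX.

```python
import jax.numpy as jnp

def checked_cholesky(A, atol=1e-6):
    """Cholesky factor of A that raises instead of returning NaNs (hypothetical helper)."""
    A = jnp.asarray(A)
    if not jnp.allclose(A, A.T, atol=atol):
        raise ValueError("A is not symmetric")
    # For symmetric A: all eigenvalues > 0 exactly when A is positive definite.
    if jnp.linalg.eigvalsh(A).min() <= 0:
        raise ValueError("A is not positive definite")
    L = jnp.linalg.cholesky(A)
    # Guard against numerical failure on ill-conditioned input.
    if jnp.isnan(L).any():
        raise ValueError("Cholesky factorization failed")
    return L
```

The eigenvalue check is itself an O(n³) computation, costlier than the factorization, so for large matrices checking L for NaN alone is a cheaper option. The Python `if` statements need concrete values, so this version is meant for eager use rather than inside `jax.jit`.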
Problem
Implement cholesky_factor(A) that returns the lower-triangular Cholesky
factor L such that L @ L.T = A.
- A: 2-D JAX array of shape (n, n), symmetric positive definite.
- Returns: 2-D array of shape (n, n), the lower-triangular factor L.
Hints
jax · linalg · cholesky