medium primitives

QR Decomposition

Why this matters

The QR decomposition factors A = Q @ R, where Q has orthonormal columns and R is upper triangular. This is a cornerstone of numerical linear algebra:

  • Least squares: solving an overdetermined system Ax ≈ b through the triangular system R x = Q.T @ b instead of the normal equations, so A.T @ A (which squares the condition number) is never formed.
  • Gram-Schmidt orthogonalization: QR is the matrix form of this process.
  • QR algorithm: the classic iterative method for computing all eigenvalues of a dense matrix.
  • Orthonormal bases: when A has full column rank, Q's columns span the column space of A.
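
The least-squares use can be sketched in a few lines. This is a minimal illustration, not a production solver: the helper name lstsq_via_qr is hypothetical, and it assumes A is tall with full column rank so that R is invertible. It uses jax.scipy.linalg.solve_triangular for the back-substitution and compares against jnp.linalg.lstsq.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def lstsq_via_qr(A, b):
    """Hypothetical helper: least-squares solution of A x ≈ b via thin QR.

    Assumes A is tall (m >= n) with full column rank, so R is invertible.
    """
    Q, R = jnp.linalg.qr(A)  # Q: (m, n) orthonormal columns, R: (n, n) upper triangular
    # ||A x - b|| is minimized by solving R x = Q.T @ b; A.T @ A is never formed.
    return solve_triangular(R, Q.T @ b, lower=False)


# Fit y = c0 + c1 * t to the points (1, 1), (2, 2), (3, 2).
A = jnp.array([[1.0, 1.0],
               [1.0, 2.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0, 2.0])

x = lstsq_via_qr(A, b)
x_ref = jnp.linalg.lstsq(A, b)[0]  # library least squares, for comparison
print(x)  # approximately [0.6667, 0.5]
```

Working only with the triangular factor keeps the conditioning of the problem close to that of A itself, which is the point of the first bullet.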

JAX's jnp.linalg.qr uses the "reduced" (thin) mode by default: for an (m, n) matrix, Q is (m, k) and R is (k, n), where k = min(m, n). For a tall matrix (m ≥ n) this means Q is (m, n) and R is square (n, n).
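
To see the shape rules concretely, the sketch below factors one tall and one wide random matrix in both the default mode and mode="complete" and records the resulting shapes. The random inputs are illustrative only; any matrices of these sizes give the same shapes.

```python
import jax
import jax.numpy as jnp

key_tall, key_wide = jax.random.split(jax.random.PRNGKey(0))
matrices = {
    "tall": jax.random.normal(key_tall, (5, 3)),  # m > n, so k = 3
    "wide": jax.random.normal(key_wide, (2, 4)),  # m < n, so k = 2
}

shapes = {}
for name, M in matrices.items():
    Q, R = jnp.linalg.qr(M)                     # default: mode="reduced"
    Qc, Rc = jnp.linalg.qr(M, mode="complete")  # square Q of shape (m, m)
    shapes[name] = (Q.shape, R.shape, Qc.shape, Rc.shape)
    print(name, *shapes[name])
# tall (5, 3) (3, 3) (5, 5) (5, 3)
# wide (2, 2) (2, 4) (2, 2) (2, 4)
```

For a wide matrix the two modes coincide, since k = m already.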

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])   # shape (3, 2)

Q, R = jnp.linalg.qr(A)
# Q: shape (3, 2), orthonormal columns (Q.T @ Q = I_2)
# R: shape (2, 2), upper triangular

jnp.linalg.norm(Q)   # Frobenius norm (the default for 2-D input)
# → sqrt(2) ≈ 1.4142  (two orthonormal columns in 3-D)
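
The claims in the comments above can be checked directly. This sketch tests only sign-invariant properties (orthonormality, reconstruction, triangularity) and computes the Frobenius norm two ways, as the square root of the sum of squares and as sqrt(trace(Q.T @ Q)). The tolerance 1e-5 is an assumption suited to float32.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
Q, R = jnp.linalg.qr(A)

# Sign-invariant checks: these hold whatever column signs the backend picks.
orthonormal = jnp.allclose(Q.T @ Q, jnp.eye(2), atol=1e-5)  # Q.T @ Q = I_2
reconstructs = jnp.allclose(Q @ R, A, atol=1e-5)            # Q @ R = A
upper = jnp.allclose(R, jnp.triu(R))                        # zeros below the diagonal

# Frobenius norm two ways: sqrt of the sum of squares, and sqrt(trace(Q.T @ Q)).
fro = jnp.linalg.norm(Q)
fro_trace = jnp.sqrt(jnp.trace(Q.T @ Q))
print(bool(orthonormal), bool(reconstructs), bool(upper))  # True True True
```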

Common pitfalls

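  • Full vs reduced mode: mode='complete' gives a square Q of shape (m, m) and an R of shape (m, n); the default mode='reduced' gives a thin Q of shape (m, k) and an R of shape (k, n).
  • Sign ambiguity: each column of Q (together with the matching row of R) is determined only up to sign, so don't compare individual columns against a reference; compare sign-invariant quantities such as Q.T @ Q or Q @ R instead.
  • Shape of R: in reduced mode R is (k, n), not (m, n). For tall A that is square (n, n); for wide A (m < n) it is (m, n) and non-square.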
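
One common way to remove the sign ambiguity is to flip signs so that diag(R) is non-negative; for a full-column-rank A, the QR factorization with a positive diagonal is unique and matches what classical Gram-Schmidt produces. The sketch below applies that convention after the fact. The helper name qr_positive_diag is hypothetical, and the convention is something this code imposes, not a guarantee of jnp.linalg.qr.

```python
import jax.numpy as jnp


def qr_positive_diag(A):
    """Hypothetical helper: reduced QR with signs chosen so that diag(R) >= 0.

    Flipping column j of Q and row j of R together leaves Q @ R unchanged,
    because diag(s) @ diag(s) = I when every entry of s is -1 or +1.
    """
    Q, R = jnp.linalg.qr(A)
    s = jnp.where(jnp.diagonal(R) < 0, -1.0, 1.0).astype(A.dtype)
    return Q * s[None, :], R * s[:, None]  # scale columns of Q, rows of R


A = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
Q, R = qr_positive_diag(A)
print(R)  # approximately [[1.4142, 0.7071], [0.0, 1.2247]]
```

After normalization, R can be compared entrywise against a hand-computed Gram-Schmidt result: here r11 = sqrt(2), r12 = 1/sqrt(2), and r22 = sqrt(1.5).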

Problem

Implement qr_decompose_q_norm(A) that returns the Frobenius norm of the Q factor from the (reduced) QR decomposition. Since Q has k = min(m, n) orthonormal columns, ‖Q‖_F = sqrt(min(m, n)).

  • A: 2-D jax array (m, n).
  • Returns: scalar, the Frobenius norm of Q.
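
The identity in the problem statement follows in one line from the trace form of the Frobenius norm, using Q^T Q = I_k for the reduced Q of shape (m, k):

```latex
\|Q\|_F^2 = \sum_{i=1}^{m}\sum_{j=1}^{k} Q_{ij}^2
          = \operatorname{tr}\!\left(Q^\top Q\right)
          = \operatorname{tr}\!\left(I_k\right) = k,
\qquad k = \min(m, n)
\quad\Longrightarrow\quad \|Q\|_F = \sqrt{\min(m, n)}.
```

The same argument applied to the square Q from mode='complete' gives sqrt(m) instead, which is why the mode matters here.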

Hints

jnp.linalg.qr (jax.numpy.linalg.qr)
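
A minimal sketch of one possible solution following the hint, not the official one: it relies on the default reduced mode and on jnp.linalg.norm returning the Frobenius norm for 2-D input.

```python
import jax.numpy as jnp


def qr_decompose_q_norm(A):
    # Reduced QR: Q has shape (m, k) with k = min(m, n) orthonormal columns.
    Q, _ = jnp.linalg.qr(A)
    # For a 2-D array, jnp.linalg.norm defaults to the Frobenius norm.
    return jnp.linalg.norm(Q)


A = jnp.array([[2.0, 1.0],
               [1.0, 3.0],
               [0.0, 1.0],
               [1.0, 0.0]])  # shape (4, 2), so k = 2
print(qr_decompose_q_norm(A))  # ≈ 1.4142, i.e. sqrt(2)
```

Because the result depends only on k, the same value comes back for any full-rank matrix of a given shape.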
