medium primitives

linalg.solve

Why this matters

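jnp.linalg.solve(A, b) solves the linear system Ax = b for x without explicitly forming the inverse of A. Under the hood it uses an LU factorisation with partial pivoting (the same method as LAPACK's gesv). This is generally both more accurate and cheaper than the naïve alternative jnp.linalg.inv(A) @ b, because forming inv(A) requires that same LU factorisation plus substantial extra work before the final matrix-vector product.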
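To make the LU step visible, here is a minimal sketch using jax.scipy.linalg.lu_factor and lu_solve, which expose the factor and solve phases separately. This is an illustration of the technique, not a claim about the exact internal code path of jnp.linalg.solve; splitting the phases is mainly useful when the same A is reused for many right-hand sides.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([4.0, 5.0])

# Factor A once: L and U are packed into one matrix, plus row-pivot indices.
lu, piv = lu_factor(A)

# Forward then back substitution with the stored factors recovers x.
x_lu = lu_solve((lu, piv), b)

# The one-call solver and the naive inverse route, for comparison.
x_solve = jnp.linalg.solve(A, b)
x_inv = jnp.linalg.inv(A) @ b

print(x_lu, x_solve, x_inv)
```

On a well-conditioned 2×2 system all three agree to float32 precision; the differences show up on larger or ill-conditioned matrices.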

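The difference matters in practice. For an ill-conditioned matrix (large condition number), the explicit-inverse route typically adds extra rounding error and leaves a larger residual A @ x - b. LU with partial pivoting is backward stable in practice: the computed x is the exact solution of a slightly perturbed system, so the residual stays small. No method removes the sensitivity that a large condition number implies, but solve avoids adding to it. For a square, non-singular A, solve is almost always the right choice.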
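A minimal sketch for comparing the two routes on an ill-conditioned matrix. The Hilbert matrix (H[i, j] = 1 / (i + j + 1)) is chosen here as a standard ill-conditioned test case, and the helper names hilbert and relative_residual are hypothetical, introduced only for this illustration. It assumes JAX's default float32 precision.

```python
import jax.numpy as jnp

def hilbert(n):
    # Hypothetical helper: H[i, j] = 1 / (i + j + 1), a classic ill-conditioned matrix.
    i = jnp.arange(n, dtype=jnp.float32)
    return 1.0 / (i[:, None] + i[None, :] + 1.0)

def relative_residual(A, x, b):
    # Hypothetical helper: ||b - A x|| / (||A|| ||x||).
    # A small value means x exactly solves a nearby system (backward stability).
    return jnp.linalg.norm(b - A @ x) / (jnp.linalg.norm(A) * jnp.linalg.norm(x))

n = 6
A = hilbert(n)
b = A @ jnp.ones(n, dtype=jnp.float32)  # true solution is (up to rounding in b) all ones

x_solve = jnp.linalg.solve(A, b)
x_inv = jnp.linalg.inv(A) @ b

print("cond(A)          :", jnp.linalg.cond(A))
print("residual, solve  :", relative_residual(A, x_solve, b))
print("residual, inv @ b:", relative_residual(A, x_inv, b))
```

On typical runs the inv(A) @ b route shows the larger residual, but the exact numbers depend on backend and precision, so treat the printout as illustrative. Note that x_solve itself may still be far from all ones: a small residual does not cancel out the conditioning of the problem.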

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([4.0, 5.0])

x = jnp.linalg.solve(A, b)
# x ≈ [1.4, 1.2]   (verify: A @ x ≈ b)
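The "verify" note can be turned into an explicit check. A minimal sketch on the same A and b; the tolerance atol=1e-5 is an assumption suited to JAX's default float32.

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([4.0, 5.0])
x = jnp.linalg.solve(A, b)

# Check the residual instead of trusting the solver blindly.
residual = A @ x - b
ok = bool(jnp.allclose(A @ x, b, atol=1e-5))
print(x, residual, ok)
```

By hand: 2(1.4) + 1.2 = 4 and 1.4 + 3(1.2) = 5, so [1.4, 1.2] is the exact solution.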

Common pitfalls

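  • Singular A does not raise an error in JAX (unlike NumPy, which raises LinAlgError); the result silently contains inf or nan. Check the condition number (jnp.linalg.cond) or the finiteness of the result.
  • A must be square: for rectangular systems, use jnp.linalg.lstsq.
  • Symmetric positive-definite A? Factor with jax.scipy.linalg.cho_factor and solve with jax.scipy.linalg.cho_solve. Cholesky needs roughly half the flops of LU, so expect up to ~2× speedup.
  • Batch of systems: jax.vmap(jnp.linalg.solve, in_axes=(0, 0))(As, bs) solves each pair; no Python loop is required.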
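The pitfalls above can be sketched in one place. This is a minimal illustration under assumptions: the matrices are made up (random ones from a fixed PRNG key, plus a deliberately singular 2×2), and the SPD matrix is built as M @ M.T + 4I so it is positive definite by construction.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Pitfall 1: a singular A does not raise in JAX; inspect the result instead.
S = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])                      # row 2 = 2 * row 1, so S is singular
print("cond(S):", jnp.linalg.cond(S))            # very large or inf
x_bad = jnp.linalg.solve(S, jnp.array([1.0, 2.0]))
print("finite?", jnp.all(jnp.isfinite(x_bad)))

# Pitfall 2: rectangular (overdetermined) systems go through least squares.
R = jax.random.normal(k1, (5, 3))
y = jax.random.normal(k2, (5,))
x_ls, resid, rank, sv = jnp.linalg.lstsq(R, y)   # x_ls has shape (3,)

# Pitfall 3: symmetric positive-definite A can use Cholesky instead of LU.
M = jax.random.normal(k3, (4, 4))
P = M @ M.T + 4.0 * jnp.eye(4)                   # SPD by construction
rhs = jnp.ones(4)
c_and_lower = cho_factor(P)                      # factor once
x_chol = cho_solve(c_and_lower, rhs)             # two triangular solves

# Pitfall 4: batch of systems via vmap over the leading axis.
As = P[None] + jnp.arange(3.0)[:, None, None] * jnp.eye(4)  # 3 SPD matrices, shape (3, 4, 4)
bs = jnp.ones((3, 4))
xs = jax.vmap(jnp.linalg.solve, in_axes=(0, 0))(As, bs)     # shape (3, 4)
```

The least-squares solution satisfies the normal equations R.T @ (R @ x_ls - y) ≈ 0, which is a convenient check when A is not square.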

Problem

Implement solve_linear_system(A, b) that returns the solution x to Ax = b.

  • A: 2-D JAX array of shape (n, n), the square, non-singular coefficient matrix.
  • b: 1-D JAX array of shape (n,), the right-hand-side vector.
  • Returns: 1-D array of shape (n,), the solution x such that A @ x ≈ b.
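A minimal sketch of one way to fill in solve_linear_system. The shape checks and their error messages are an assumption added for illustration (the problem only promises square, non-singular inputs). Because JAX array shapes are static, these checks run at trace time and the function still works under jax.jit.

```python
import jax
import jax.numpy as jnp

def solve_linear_system(A, b):
    """Return x with A @ x ≈ b for a square, non-singular A of shape (n, n)."""
    A = jnp.asarray(A)
    b = jnp.asarray(b)
    # Assumed input validation: shapes are static, so this is fine under jit.
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must have shape (n, n), got {A.shape}")
    if b.shape != (A.shape[0],):
        raise ValueError(f"b must have shape ({A.shape[0]},), got {b.shape}")
    # LU-based solve; inv(A) is never formed.
    return jnp.linalg.solve(A, b)

solve_jit = jax.jit(solve_linear_system)
```

Note that the checks cannot detect a singular A (that would need a data-dependent test, which JAX cannot raise on under jit).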

Hints

jax linalg solve
