linalg.inv
Why this matters
jnp.linalg.inv(A) computes the explicit matrix inverse A⁻¹ such that
A⁻¹ A = I. Unlike jnp.linalg.solve, it materialises the full inverse
matrix. Both scale as O(n³), but forming A⁻¹ takes several times the
floating-point work of a single solve, and the result is numerically
fragile when A is ill-conditioned.
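To make the comparison concrete, here is a small sketch that solves the same system both ways. The matrix A and vector b are made-up values, chosen only so that A is well-conditioned. Both routes should give the same x here; solve just never forms A⁻¹.

```python
import jax.numpy as jnp

# Small, well-conditioned example system (values chosen for illustration).
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

# Preferred: factorise A and solve A x = b directly.
x_solve = jnp.linalg.solve(A, b)

# Alternative: build the full inverse, then multiply by b.
x_inv = jnp.linalg.inv(A) @ b

# For a well-conditioned A both satisfy A x ≈ b and agree closely.
print(jnp.allclose(A @ x_solve, b, atol=1e-5))
print(jnp.allclose(x_solve, x_inv, atol=1e-5))
```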
Knowing when to use it matters as much as knowing how:
prefer solve(A, b) whenever you want A⁻¹b. Use explicit inv only
when you genuinely need the matrix itself, for example the inverse
innovation covariance S⁻¹ in the Kalman gain K = P Hᵀ S⁻¹ (with
S = H P Hᵀ + R), or Gaussian process posterior computations where the
inverse covariance matrix appears directly in the predictive formula.
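The Kalman gain K = P Hᵀ S⁻¹, with innovation covariance S = H P Hᵀ + R, is a place where an explicit inverse appears in the formula. The sketch below assumes a hypothetical two-state, one-measurement filter with made-up P, H and R. It computes K with the formula's S⁻¹ and also by rearranging K S = P Hᵀ into a solve, and both forms should agree.

```python
import jax.numpy as jnp

# Hypothetical 2-state, 1-measurement filter (values are illustrative only).
P = jnp.array([[1.0, 0.2],
               [0.2, 2.0]])   # prior state covariance (n x n)
H = jnp.array([[1.0, 0.0]])   # measurement matrix (m x n)
R = jnp.array([[0.5]])        # measurement noise covariance (m x m)

# Innovation covariance S = H P Hᵀ + R (m x m).
S = H @ P @ H.T + R

# Gain exactly as written in the formula: K = P Hᵀ S⁻¹ (n x m).
K_inv = P @ H.T @ jnp.linalg.inv(S)

# Same gain without forming S⁻¹: K S = P Hᵀ  =>  Sᵀ Kᵀ = (P Hᵀ)ᵀ.
K_solve = jnp.linalg.solve(S.T, (P @ H.T).T).T

print(K_inv)
print(K_solve)
```

With a single scalar measurement S is 1 × 1, so both routes cost about the same. The difference only starts to matter as the measurement dimension m grows.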
Worked mini-example
```python
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0],
               [0.0, 4.0]])
Ainv = jnp.linalg.inv(A)
# → [[0.5, 0.0 ],
#    [0.0, 0.25]]

# Verify: A @ Ainv ≈ I
print(jnp.allclose(A @ Ainv, jnp.eye(2)))  # True
```
Common pitfalls
- Prefer solve(A, b) over inv(A) @ b for solving linear systems; it is faster and more stable.
- Ill-conditioned A: the inverse will contain large round-off errors. Check jnp.linalg.cond(A) before inverting.
- Singular A: unlike NumPy's np.linalg.inv, which raises LinAlgError, jnp.linalg.inv does not raise. The result typically contains inf or nan entries (or huge, meaningless values). Regularise with a small diagonal term (A + ε * I) if needed.
- Cost is O(n³); avoid for large n unless the full matrix is required.
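The condition-number check and the ε · I regularisation from these pitfalls can be combined into one helper. This is a minimal sketch. The name regularised_inverse and its defaults eps=1e-2 and max_cond=1e4 are hypothetical, illustrative choices and not values recommended by JAX. The ridge is chosen with jnp.where rather than a Python if, so the function can also run under jax.jit.

```python
import jax.numpy as jnp

def regularised_inverse(A, eps=1e-2, max_cond=1e4):
    """Hypothetical helper: invert A, adding eps * I first if A looks ill-conditioned.

    eps and max_cond are illustrative defaults, not JAX recommendations.
    """
    cond = jnp.linalg.cond(A)
    # Treat a non-finite condition number (e.g. exact singularity) as too large.
    too_ill = (cond > max_cond) | ~jnp.isfinite(cond)
    # jnp.where instead of a Python `if` keeps this traceable under jax.jit.
    ridge = jnp.where(too_ill, eps, 0.0)
    eye = jnp.eye(A.shape[-1], dtype=A.dtype)
    return jnp.linalg.inv(A + ridge * eye)

singular = jnp.array([[1.0, 2.0],
                      [2.0, 4.0]])    # second row = 2 × first row
print(jnp.linalg.cond(singular))      # very large or inf
print(regularised_inverse(singular))  # finite, but only an approximation
```

The regularised result is the inverse of A + ε · I, not of A itself, so it only makes sense where that bias is acceptable (for example, a small ridge on a covariance matrix).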
Problem
Implement matrix_inverse(A) that returns A⁻¹ using jnp.linalg.inv.
- A: 2-D JAX array of shape (n, n); square and non-singular.
- Returns: 2-D array of shape (n, n); the matrix inverse.
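One way to sanity-check an implementation against this spec is a small property test. The helper check_inverse below is hypothetical and not part of the problem. It builds a random matrix M Mᵀ + n I, which is symmetric positive definite with all eigenvalues at least n and therefore safely invertible. It then checks the output shape and that A @ A⁻¹ ≈ I. Once matrix_inverse is written, you would call check_inverse(matrix_inverse).

```python
import jax
import jax.numpy as jnp

def check_inverse(inverse_fn, n=4, seed=0, atol=1e-4):
    """Hypothetical self-check for a matrix_inverse implementation.

    Raises AssertionError if the shape is wrong or A @ A_inv is not close to I.
    """
    key = jax.random.PRNGKey(seed)
    M = jax.random.normal(key, (n, n))
    # M Mᵀ + n I is symmetric positive definite with eigenvalues >= n,
    # so it is well-conditioned and safely invertible.
    A = M @ M.T + n * jnp.eye(n)

    A_inv = inverse_fn(A)
    assert A_inv.shape == (n, n), f"expected shape {(n, n)}, got {A_inv.shape}"
    assert jnp.allclose(A @ A_inv, jnp.eye(n), atol=atol), "A @ A_inv is not close to I"
    return True
```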