easy primitives

linalg.inv

Why this matters

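jnp.linalg.inv(A) computes the explicit matrix inverse A⁻¹ such that A⁻¹A = AA⁻¹ = I. Unlike jnp.linalg.solve, it materialises the full n × n inverse. Both rely on an O(n³) factorisation, but forming every column of the inverse costs roughly three times the arithmetic of solving for a single right-hand side. Multiplying by a computed inverse is also generally less accurate than a direct solve when A is ill-conditioned.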
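To make "materialises the full inverse" concrete: column i of A⁻¹ is the solution of A x = eᵢ, so inv(A) matches solve(A, I). The sketch below checks that, and that the inverse works from both sides. The 2×2 matrix is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

# Arbitrary, well-conditioned, non-symmetric 2x2 matrix chosen for illustration.
A = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])
I = jnp.eye(2)

# Explicit inverse: the full (n, n) matrix A⁻¹ is materialised.
A_inv = jnp.linalg.inv(A)

# Column i of A⁻¹ solves A x = e_i, so solving against the identity gives the same matrix.
A_inv_via_solve = jnp.linalg.solve(A, I)

# The inverse works from both sides, up to float32 round-off.
left_ok = jnp.allclose(A_inv @ A, I, atol=1e-5)
right_ok = jnp.allclose(A @ A_inv, I, atol=1e-5)
same_ok = jnp.allclose(A_inv, A_inv_via_solve, atol=1e-6)
print(left_ok, right_ok, same_ok)
```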

Knowing when to use it matters as much as knowing how: prefer solve(A, b) whenever you want A⁻¹b. Use an explicit inv only when you genuinely need the matrix itself, or when an inverse sits inside a small matrix formula. Examples include the Kalman filter gain K = PHᵀS⁻¹ (with innovation covariance S = HPHᵀ + R) that feeds the covariance update P ← (I − KH)P, and Gaussian process posterior computations, where the inverse covariance matrix appears directly in the predictive formula.
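A minimal sketch of the Kalman gain case, assuming a hypothetical filter with 2 states and 1 measurement (the values of P, H and R are made up for illustration). It computes K once with an explicit inv(S), as the formula is written, and once with a solve that never forms S⁻¹. Both results should agree.

```python
import jax.numpy as jnp

# Hypothetical 2-state, 1-measurement filter; all values are illustrative.
P = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])   # prior state covariance (2x2)
H = jnp.array([[1.0, 0.0]])   # measurement model: observe the first state (1x2)
R = jnp.array([[0.5]])        # measurement noise covariance (1x1)

S = H @ P @ H.T + R           # innovation covariance (1x1)

# Gain with an explicit inverse of S, exactly as the formula is written.
K_inv = P @ H.T @ jnp.linalg.inv(S)

# Same gain without forming S⁻¹: K = P Hᵀ S⁻¹ is equivalent to Kᵀ = S⁻ᵀ (P Hᵀ)ᵀ.
K_solve = jnp.linalg.solve(S.T, (P @ H.T).T).T

# Covariance update P ← (I − K H) P.
P_new = (jnp.eye(2) - K_inv @ H) @ P
print(K_inv.ravel(), K_solve.ravel())
print(P_new)
```

S is only m × m, where m is the measurement dimension. That is one reason an explicit inverse is often tolerated here: it is cheap when m is small.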

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[2.0, 0.0],
               [0.0, 4.0]])

Ainv = jnp.linalg.inv(A)
# → [[0.5,  0.0 ],
#    [0.0,  0.25]]

# Verify: A @ Ainv ≈ I (allclose tolerates float32 round-off)
print(jnp.allclose(A @ Ainv, jnp.eye(2)))  # → True

Common pitfalls

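  • Prefer solve(A, b) over inv(A) @ b for solving linear systems: it does less work and is more accurate.
  • Ill-conditioned A: the computed inverse can carry relative errors on the order of cond(A) times machine epsilon. By default JAX uses float32, which has only about 7 significant digits to lose. Check jnp.linalg.cond(A) before inverting.
  • Singular A: NumPy raises LinAlgError, but JAX does not raise. The result typically contains inf or nan, or huge, meaningless values if round-off leaves a tiny non-zero pivot. Regularise with a small diagonal term (A + ε * I) if needed, bearing in mind that this changes the answer.
  • Cost is O(n³) time and O(n²) memory: avoid for large n unless the full matrix is required.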
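The pitfalls above can be sketched in a few lines. The matrices and the regularisation strength eps = 1e-3 are illustrative assumptions, not recommended values. The singular case is printed without a claimed output because the exact non-finite pattern can vary.

```python
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([1.0, 0.0])

# Pitfall 1: prefer solve over forming the inverse and multiplying.
x_solve = jnp.linalg.solve(A, b)
x_inv = jnp.linalg.inv(A) @ b
print(jnp.allclose(x_solve, x_inv, atol=1e-6))

# Pitfall 2: check conditioning first. Rule of thumb: expect to lose about
# log10(cond(A)) significant digits; float32 (JAX's default) has roughly 7.
cond_A = jnp.linalg.cond(A)
print(cond_A)

# Pitfall 3: a singular matrix does not raise in JAX; the result may be non-finite.
S = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])
print(jnp.isfinite(jnp.linalg.inv(S)).all())

# Regularise with a small diagonal term (eps is an arbitrary illustrative choice).
# The result is the inverse of S + eps*I, not of S itself.
eps = 1e-3
S_reg_inv = jnp.linalg.inv(S + eps * jnp.eye(2))
print(jnp.isfinite(S_reg_inv).all())
```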

Problem

Implement matrix_inverse(A) that returns A⁻¹ using jnp.linalg.inv.

  • A: 2-D JAX array of shape (n, n); square and non-singular.
  • Returns: 2-D array of shape (n, n), the matrix inverse A⁻¹.
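One straightforward way to meet this specification, shown as a sketch rather than the site's reference solution, is to delegate to jnp.linalg.inv and check the result against the identity. The 3×3 test matrix is an arbitrary non-singular choice. Because the input is guaranteed non-singular, no regularisation is applied. The function is also wrapped in jax.jit to show it traces cleanly.

```python
import jax
import jax.numpy as jnp

def matrix_inverse(A: jax.Array) -> jax.Array:
    """Return A⁻¹ for a square, non-singular (n, n) array via jnp.linalg.inv."""
    return jnp.linalg.inv(A)

# Arbitrary non-singular test matrix (determinant 18).
A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])

Ainv = matrix_inverse(A)
Ainv_jit = jax.jit(matrix_inverse)(A)   # same computation, compiled

print(jnp.allclose(A @ Ainv, jnp.eye(3), atol=1e-5))
```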

Hints

See jax.numpy.linalg.inv.
