medium primitives

linalg.matrix_power

Why this matters

jnp.linalg.matrix_power(A, n) computes the matrix power A^n: the product of n copies of A under matrix multiplication. This is fundamentally different from A ** n, which raises each element to the nth power (elementwise).
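To make the distinction concrete, here is a small sketch that squares the same matrix both ways. The matrix values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# Matrix power: A @ A
matrix_sq = jnp.linalg.matrix_power(A, 2)
# [[ 7., 10.],
#  [15., 22.]]

# Elementwise power: every entry squared independently
elementwise_sq = A ** 2
# [[ 1.,  4.],
#  [ 9., 16.]]
```

The two results agree only in special cases, for example when A is diagonal.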

Matrix powers appear in:

  • Markov chains: the n-step transition matrix is T^n.
  • Recurrence relations: Fibonacci numbers in O(log n) matrix multiplications.
  • Discrete dynamical systems: the state after n steps is A^n @ x₀.
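Two of the uses listed above can be sketched in a few lines. The transition matrix T and the names two_step, Q, and fib_10 are illustrative assumptions, not part of the exercise. Float arrays are used throughout, since the result dtype for integer inputs may depend on the JAX version.

```python
import jax.numpy as jnp

# Markov chain: rows of T are "from" states, columns are "to" states.
T = jnp.array([[0.9, 0.1],
               [0.5, 0.5]])
two_step = jnp.linalg.matrix_power(T, 2)   # 2-step transition probabilities
# Each row of two_step still sums to 1.

# Fibonacci: Q^n = [[F(n+1), F(n)], [F(n), F(n-1)]]
Q = jnp.array([[1.0, 1.0],
               [1.0, 0.0]])
fib_10 = jnp.linalg.matrix_power(Q, 10)[0, 1]   # F(10) = 55
```

For large n the Fibonacci entries outgrow what a float32 can represent exactly, so this approach is only a demonstration of the recurrence, not a general-purpose Fibonacci routine.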

For large n, the implementation uses fast exponentiation (repeated squaring), which needs O(log n) matrix multiplications instead of the O(n) of a naive loop.
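The idea behind repeated squaring can be sketched as follows. This is a minimal illustration for non-negative exponents, not the library's actual source; the function name matrix_power_by_squaring is hypothetical.

```python
import jax.numpy as jnp

def matrix_power_by_squaring(A, n):
    """Hypothetical helper: A^n for integer n >= 0 via binary exponentiation."""
    n = int(n)
    if n < 0:
        raise ValueError("this sketch only handles n >= 0")
    result = jnp.eye(A.shape[0], dtype=A.dtype)  # A^0 = identity
    base = A
    while n > 0:
        if n & 1:                 # lowest bit set: include this power of A
            result = result @ base
        n >>= 1
        if n > 0:                 # square only if more bits remain
            base = base @ base
    return result

A = jnp.array([[1.0, 1.0],
               [0.0, 1.0]])
fast = matrix_power_by_squaring(A, 13)
reference = jnp.linalg.matrix_power(A, 13)
```

The loop runs once per bit of n, so the number of matrix multiplications grows with log n rather than n.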

Worked mini-example

import jax.numpy as jnp

# Swap matrix: swaps the two elements
A = jnp.array([[0.0, 1.0],
               [1.0, 0.0]])

jnp.linalg.matrix_power(A, 2)
# → [[1.0, 0.0],
#    [0.0, 1.0]]   (swap twice = identity)

jnp.linalg.matrix_power(A, 3)
# → [[0.0, 1.0],
#    [1.0, 0.0]]   (three swaps = one swap, so A^3 = A)

Common pitfalls

  • A ** n is elementwise: it is not the same as matrix_power(A, n).
  • n must be a Python int: JAX needs the exponent as a concrete value at trace time, so pass int(n).
  • Negative n: computes inv(A)^|n|; requires A to be invertible.
  • n=0: returns the identity matrix regardless of A.
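The edge cases above can be checked directly. The matrix below is an arbitrary invertible example (its determinant is 1), and the variable names are illustrative only.

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0],
               [1.0, 1.0]])   # det = 1, so A is invertible

# n = 0 gives the identity, whatever A is
identity = jnp.linalg.matrix_power(A, 0)

# Negative n: same as raising inv(A) to |n|
inverse_sq = jnp.linalg.matrix_power(A, -2)
via_inv = jnp.linalg.matrix_power(jnp.linalg.inv(A), 2)

# A float exponent should be cast to a Python int first
n = 3.0
cubed = jnp.linalg.matrix_power(A, int(n))
```

If A is singular, a negative exponent has no meaningful result, so it is worth checking invertibility before relying on it.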

Problem

Implement matrix_power(A, n) that returns A^n (matrix to the nth power).

  • A: 2-D jax array (m, m), a square matrix.
  • n: scalar, the exponent (cast to int inside the function).
  • Returns: 2-D array (m, m), A raised to the nth matrix power.

Hints

jax linalg matrix-power
