medium
primitives
linalg.matrix_power
Why this matters
jnp.linalg.matrix_power(A, n) computes the matrix power A^n: the
product of n copies of A under matrix multiplication. This is fundamentally different
from A ** n, which raises each element to the nth power (elementwise).
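
To make the distinction concrete, here is a small sketch comparing the two operations on the same matrix. The matrix values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

# Arbitrary example matrix (not from the original text).
A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

matrix_sq = jnp.linalg.matrix_power(A, 2)  # A @ A
elementwise_sq = A ** 2                    # each entry squared independently

# matrix_sq      is [[ 7., 10.], [15., 22.]]
# elementwise_sq is [[ 1.,  4.], [ 9., 16.]]
```

The two results agree only in special cases, such as diagonal matrices.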
Matrix powers appear in:
- Markov chains: the n-step transition matrix is T^n.
- Recurrence relations: closed-form Fibonacci in O(log n) multiplications.
- Discrete dynamical systems: state after n steps is A^n @ x₀.
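
As an illustration of the Markov chain case, the sketch below builds a hypothetical 2-state chain and computes its 3-step transition matrix. The transition probabilities are made up, and the code assumes the row-stochastic convention (row i holds the probabilities of leaving state i), so distributions are row vectors multiplied on the left.

```python
import jax.numpy as jnp

# Hypothetical 2-state chain; each row sums to 1 (row-stochastic convention).
T = jnp.array([[0.9, 0.1],
               [0.5, 0.5]])

T3 = jnp.linalg.matrix_power(T, 3)  # 3-step transition matrix

x0 = jnp.array([1.0, 0.0])          # start in state 0 with certainty
x3 = x0 @ T3                        # distribution over states after 3 steps
```

Under the column-vector convention used for dynamical systems, the same step would be written T3 @ x0 with a column-stochastic T.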
For large n, the implementation uses fast exponentiation (repeated squaring) in O(log n) matrix multiplications instead of O(n).
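
The idea behind repeated squaring can be sketched as follows. This is a minimal illustration of the technique, not JAX's actual source; the function name power_by_squaring is hypothetical, and the sketch only handles n >= 0.

```python
import jax.numpy as jnp

def power_by_squaring(A, n):
    """Illustrative A^n for an integer n >= 0 using O(log n) matmuls."""
    n = int(n)
    if n < 0:
        raise ValueError("this sketch only handles n >= 0")
    result = jnp.eye(A.shape[0], dtype=A.dtype)  # A^0 is the identity
    base = A
    while n > 0:
        if n & 1:                 # current binary digit of n is 1
            result = result @ base
        base = base @ base        # square the base for the next binary digit
        n >>= 1
    return result

# Fibonacci matrix: F^n = [[F(n+1), F(n)], [F(n), F(n-1)]]
F = jnp.array([[1.0, 1.0],
               [1.0, 0.0]])
fib10 = power_by_squaring(F, 10)[0, 1]  # 10th Fibonacci number, 55.0
```

Each loop iteration consumes one binary digit of n, which is where the O(log n) count comes from.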
Worked mini-example
import jax.numpy as jnp

# Swap matrix: swaps the two elements
A = jnp.array([[0.0, 1.0],
               [1.0, 0.0]])

jnp.linalg.matrix_power(A, 2)
# → [[1.0, 0.0],
#    [0.0, 1.0]]   (swap twice = identity)

jnp.linalg.matrix_power(A, 3)
# → [[0.0, 1.0],
#    [1.0, 0.0]]   (swap three times = swap once = original)
Common pitfalls
- `A ** n` is elementwise: not the same as `matrix_power(A, n)`.
- n must be a Python int: JAX traces shapes statically; pass int(n).
- Negative n: computes inv(A)^|n|; requires A to be invertible.
- n=0: returns the identity matrix regardless of A.
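
The edge cases above can be checked with a short sketch. The matrix is an arbitrary invertible example (determinant 1), chosen so the inverse has small integer entries.

```python
import jax.numpy as jnp

# Arbitrary invertible matrix: det = 2*1 - 1*1 = 1, inv(A) = [[1, -1], [-1, 2]].
A = jnp.array([[2.0, 1.0],
               [1.0, 1.0]])

identity = jnp.linalg.matrix_power(A, 0)      # n = 0: identity, whatever A is
inverse_sq = jnp.linalg.matrix_power(A, -2)   # n < 0: inv(A) @ inv(A)

n = 3.0                                       # exponent arrives as a float
cube = jnp.linalg.matrix_power(A, int(n))     # cast to a Python int first
```

For a singular A, a negative exponent has no well-defined result, so it is worth checking invertibility before relying on it.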
Problem
Implement matrix_power(A, n) that returns A^n (matrix to the nth power).
- A: 2-D jax array (m, m), a square matrix.
- n: scalar, the exponent (cast to int inside the function).
- Returns: 2-D array (m, m), A multiplied by itself n times.
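
One possible shape for such a function is sketched below. It is not the site's reference solution; it simply delegates to jnp.linalg.matrix_power after casting the exponent, and it assumes n is a concrete value (not a traced one) so that int(n) succeeds.

```python
import jax.numpy as jnp

def matrix_power(A, n):
    # Cast the exponent to a Python int, as the problem statement asks.
    # Assumes n is concrete (e.g. a Python number or a 0-d array outside jit).
    return jnp.linalg.matrix_power(A, int(n))

# Usage with the swap matrix from the worked example.
S = jnp.array([[0.0, 1.0],
               [1.0, 0.0]])
S2 = matrix_power(S, 2.0)   # float exponent is cast to 2
S3 = matrix_power(S, 3)
```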
jax
linalg
matrix-power