easy
primitives
linalg.norm (Frobenius)
Why this matters
jnp.linalg.norm is the go-to function for measuring the “size” of vectors
and matrices. Its default behaviour depends on the rank of the input:
- 1-D input → L2 (Euclidean) norm: sqrt(sum(x²)).
- 2-D input → Frobenius norm: sqrt(sum(A²)) (not spectral).
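A quick check of the rank-dependent default, comparing jnp.linalg.norm against the formulas above written out by hand. This is a small sketch; the arrays x and A are arbitrary illustrative values.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 2.0])   # 1-D vector
A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])      # 2-D matrix

# 1-D default: L2 (Euclidean) norm, sqrt(1 + 4 + 4) = 3
vec_norm = jnp.linalg.norm(x)
vec_manual = jnp.sqrt(jnp.sum(x ** 2))

# 2-D default: Frobenius norm, sqrt(1 + 4 + 9 + 16) = sqrt(30)
mat_norm = jnp.linalg.norm(A)
mat_manual = jnp.sqrt(jnp.sum(A ** 2))

print(vec_norm, vec_manual)   # both 3.0
print(mat_norm, mat_manual)   # both ≈ 5.477
```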
The Frobenius norm treats the matrix as a flat vector of entries, which makes
it cheap (O(mn), with no SVD needed, unlike the spectral norm) and
differentiable everywhere except at the zero matrix; its square ‖W‖_F² is
smooth everywhere. It appears in weight decay regularisation (‖W‖_F²),
gradient clipping, and distance metrics between parameter tensors.
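To make the “flat vector” view and the two training uses concrete, here is a sketch that checks the norm against a flattened copy, confirms that the gradient of the squared penalty ‖W‖_F² is 2W, and rescales a tensor to a maximum Frobenius norm. The helper names weight_decay_penalty and clip_by_frobenius and the threshold max_norm are hypothetical, chosen for illustration; this is not any particular library's clipping API. The clipping shown is per-tensor; many training loops instead clip by the global norm over all parameter tensors.

```python
import jax
import jax.numpy as jnp

W = jnp.array([[0.5, -1.0,  2.0],
               [1.5,  0.0, -0.5]])

# Flat-vector view: the Frobenius norm equals the L2 norm of the raveled entries.
same_as_flat = jnp.allclose(jnp.linalg.norm(W), jnp.linalg.norm(W.ravel()))

def weight_decay_penalty(W):
    # Hypothetical helper: squared Frobenius norm written as a plain sum of
    # squares, with no square root, so it is smooth everywhere (even at W = 0).
    return jnp.sum(W ** 2)

# The gradient of ||W||_F^2 is 2W.
penalty_grad = jax.grad(weight_decay_penalty)(W)

def clip_by_frobenius(G, max_norm):
    # Hypothetical helper: rescale G so its Frobenius norm is at most max_norm.
    # Tensors already within the limit are returned unchanged (scale = 1).
    norm = jnp.linalg.norm(G)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))  # epsilon guards norm = 0
    return G * scale

clipped = clip_by_frobenius(W, max_norm=1.0)
print(same_as_flat)                          # True
print(jnp.allclose(penalty_grad, 2.0 * W))   # True
print(jnp.linalg.norm(clipped))              # ≈ 1.0
```

Writing the penalty as jnp.sum(W ** 2) rather than jnp.linalg.norm(W) ** 2 avoids routing the gradient through a square root, which is not differentiable at zero.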
Worked mini-example
import jax.numpy as jnp

A = jnp.array([[3.0, 0.0],
               [0.0, 4.0]])
jnp.linalg.norm(A)             # → 5.0 (Frobenius = sqrt(9+16))
jnp.linalg.norm(A, ord=2)      # → 4.0 (spectral = largest singular value)
jnp.linalg.norm(A, ord='fro')  # → 5.0 (explicit Frobenius)
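Both numbers in the mini-example can be cross-checked against the singular values: the spectral norm is the largest singular value, and the Frobenius norm is the square root of the sum of all squared singular values. A sketch using jnp.linalg.svd with compute_uv=False on a non-diagonal matrix (the values of B are arbitrary):

```python
import jax.numpy as jnp

B = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# Singular values only (no U, V factors), returned in descending order.
s = jnp.linalg.svd(B, compute_uv=False)

fro_from_svd = jnp.sqrt(jnp.sum(s ** 2))   # Frobenius = sqrt(sum of sigma_i^2)
spec_from_svd = s[0]                       # spectral = largest sigma

print(jnp.linalg.norm(B), fro_from_svd)          # both ≈ 5.477
print(jnp.linalg.norm(B, ord=2), spec_from_svd)  # both ≈ 5.465
```

This is also why the default is cheap: the Frobenius norm needs only a sum of squares over the entries, while ord=2 requires computing singular values.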
Common pitfalls
- Default for 2-D is Frobenius, not spectral. If you want the largest singular value, use ord=2.
- ord=None vs ord='fro': both give Frobenius for 2-D; None is the default.
- For 1-D: the default (ord=None) gives the L2 norm; other ords behave as expected (ord=1 for L1, ord=jnp.inf for L∞, etc.).
- Batched norms: wrap with jax.vmap to norm a batch of matrices, or pass axis=(-2, -1) to norm over the last two axes.
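The pitfalls above in runnable form. On a non-diagonal matrix the default and ord=2 visibly disagree; the 1-D ord values are the standard ones; and a batch of matrices can be normed either with jax.vmap or with axis=(-2, -1). The array values are arbitrary, chosen so the expected numbers are easy to verify by hand.

```python
import jax
import jax.numpy as jnp

# Pitfall 1 and 2: for 2-D input the default is Frobenius, not spectral.
M = jnp.array([[1.0, 1.0],
               [0.0, 1.0]])
fro_default = jnp.linalg.norm(M)              # sqrt(3) ≈ 1.732
fro_explicit = jnp.linalg.norm(M, ord='fro')  # same value as the default
spectral = jnp.linalg.norm(M, ord=2)          # largest singular value ≈ 1.618

# Pitfall 3: common 1-D ords.
v = jnp.array([3.0, -4.0])
l1 = jnp.linalg.norm(v, ord=1)          # |3| + |-4| = 7
l2 = jnp.linalg.norm(v)                 # sqrt(9 + 16) = 5
linf = jnp.linalg.norm(v, ord=jnp.inf)  # max(|3|, |-4|) = 4

# Pitfall 4: batched matrix norms.
batch = jnp.stack([M, 2.0 * M])                          # shape (2, 2, 2)
per_matrix_vmap = jax.vmap(jnp.linalg.norm)(batch)       # shape (2,)
per_matrix_axis = jnp.linalg.norm(batch, axis=(-2, -1))  # same values, no vmap
whole_batch = jnp.linalg.norm(batch)  # no axis: ONE scalar over all entries

print(fro_default, spectral)
print(l1, l2, linf)
print(per_matrix_vmap, per_matrix_axis, whole_batch)
```

Note the last line: with no axis and the default ord, the 3-D batch collapses to a single norm over every entry, which is usually not what a per-matrix computation wants.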
Problem
Implement matrix_frobenius(A) that returns the Frobenius norm of a 2-D
matrix using jnp.linalg.norm.
- A: 2-D JAX array of any shape (m, n).
- Returns: a scalar, sqrt(sum(A²)).
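One possible implementation, kept to the signature in the problem statement. The explicit rank check is an extra assumption (the statement only says A is 2-D); drop it if your grader passes other ranks.

```python
import jax
import jax.numpy as jnp

def matrix_frobenius(A):
    """Frobenius norm of a 2-D array: sqrt(sum(A**2))."""
    A = jnp.asarray(A)
    if A.ndim != 2:  # assumption: reject non-matrix input explicitly
        raise ValueError(f"expected a 2-D array, got shape {A.shape}")
    # ord=None (the default) already means Frobenius for 2-D input;
    # 'fro' just makes the intent explicit.
    return jnp.linalg.norm(A, ord='fro')

# Usage: works on any (m, n) shape, including non-square matrices.
A = jnp.arange(6.0).reshape(2, 3)    # entries 0..5
print(matrix_frobenius(A))           # sqrt(0+1+4+9+16+25) = sqrt(55) ≈ 7.416
print(jax.jit(matrix_frobenius)(A))  # same value under jit
```

The rank check works under jax.jit because ndim is static shape information, not a traced value.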
Hints: jax, linalg, norm