easy primitives

linalg.norm (Frobenius)

Why this matters

jnp.linalg.norm is the go-to function for measuring the “size” of vectors and matrices. With the default ord=None, its behaviour depends on the rank of the input:

  • 1-D input → L2 (Euclidean) norm: sqrt(sum(x²)).
  • 2-D input → Frobenius norm: sqrt(sum(A²)) (not spectral).
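A quick way to confirm the two defaults is to compare jnp.linalg.norm against the formulas directly. The sketch below uses arbitrary example arrays. For the 2-D case it also checks that the Frobenius norm equals the L2 norm of the flattened matrix.

```python
import jax.numpy as jnp

# Arbitrary example inputs
x = jnp.array([3.0, 4.0])
A = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# 1-D, ord=None: Euclidean (L2) norm
vec_default = jnp.linalg.norm(x)
vec_manual = jnp.sqrt(jnp.sum(x ** 2))

# 2-D, ord=None: Frobenius norm
mat_default = jnp.linalg.norm(A)
mat_manual = jnp.sqrt(jnp.sum(A ** 2))

# The Frobenius norm is the L2 norm of the flattened entries
mat_flat = jnp.linalg.norm(A.ravel())

print(vec_default, mat_default, mat_manual, mat_flat)
```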

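The Frobenius norm treats the matrix as a flat vector of entries, so it is cheap to compute (O(mn) for an m×n matrix, with no matrix decomposition needed). It is differentiable everywhere except at the zero matrix, where the square root has no derivative; its square ‖W‖_F² is smooth everywhere. It appears in weight decay regularisation (‖W‖_F²), gradient clipping by global norm, and distance metrics between parameter tensors (e.g. ‖W₁ − W₂‖_F).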
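To make the weight-decay use concrete, here is a minimal sketch of an L2 penalty λ‖W‖_F². The helper name weight_decay_penalty and the coefficient lam=0.01 are hypothetical choices. The sketch sums squares directly instead of squaring jnp.linalg.norm, so autodiff never passes through the square root and the gradient stays well defined at W = 0. It then checks jax.grad against the analytic gradient 2λW.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: L2 weight-decay penalty lam * ||W||_F^2.
# Summing squares directly avoids the sqrt inside jnp.linalg.norm,
# so the penalty and its gradient are smooth even at W = 0.
def weight_decay_penalty(W, lam=0.01):
    return lam * jnp.sum(W ** 2)

W = jnp.array([[1.0, -2.0],
               [0.5,  3.0]])

penalty = weight_decay_penalty(W)
# The same value, computed via the Frobenius norm
penalty_via_norm = 0.01 * jnp.linalg.norm(W) ** 2

# The analytic gradient of lam * ||W||_F^2 is 2 * lam * W
grad = jax.grad(weight_decay_penalty)(W)

# At the zero matrix the squared form still has a finite gradient (all zeros)
grad_at_zero = jax.grad(weight_decay_penalty)(jnp.zeros((2, 2)))
```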

Worked mini-example

import jax.numpy as jnp

A = jnp.array([[3.0, 0.0],
               [0.0, 4.0]])

jnp.linalg.norm(A)             # → 5.0  (Frobenius = sqrt(9 + 16))
jnp.linalg.norm(A, ord=2)      # → 4.0  (spectral = largest singular value)
jnp.linalg.norm(A, ord='fro')  # → 5.0  (explicit Frobenius)
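Both norms in the example can be recovered from the singular values σ₁ ≥ σ₂ ≥ … ≥ 0. The Frobenius norm equals sqrt(σ₁² + σ₂² + …) and the spectral norm is σ₁, so ‖A‖_F ≥ ‖A‖₂ always. The sketch below checks this with jnp.linalg.svd(..., compute_uv=False) on a non-diagonal matrix B, an arbitrary choice where the two norms are close but not equal.

```python
import jax.numpy as jnp

# Arbitrary non-diagonal example
B = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# Singular values only, sorted in descending order
s = jnp.linalg.svd(B, compute_uv=False)

fro_from_svd = jnp.sqrt(jnp.sum(s ** 2))  # Frobenius = sqrt(sum of squared singular values)
spec_from_svd = s[0]                      # spectral = largest singular value

fro = jnp.linalg.norm(B)          # default for 2-D: Frobenius
spec = jnp.linalg.norm(B, ord=2)  # spectral

print(fro, spec)
```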

Common pitfalls

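  • Default for 2-D is Frobenius, not spectral. If you want the largest singular value, use ord=2.
  • ord=None vs ord='fro': both give the Frobenius norm for 2-D input; None is the default.
  • For 1-D: the default (ord=None) gives the L2 norm; ord=1 gives the L1 norm (sum of |xᵢ|), ord=jnp.inf gives max |xᵢ|, and other values of p give (sum |xᵢ|^p)^(1/p).
  • Batched norms: passing a stack of matrices (3-D or higher) with the default ord=None and axis=None flattens it and returns a single norm for the whole stack. To get one norm per matrix, wrap the call in jax.vmap or pass axis=(-2, -1).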
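The batching and vector-ord pitfalls can be checked in a few lines. This sketch uses an arbitrary (3, 2, 2) stack and compares jax.vmap with the axis=(-2, -1) argument. It also shows the single flattened result you get by passing the stack directly.

```python
import jax
import jax.numpy as jnp

# Arbitrary batch of three 2x2 matrices, shape (3, 2, 2)
batch = jnp.arange(12.0).reshape(3, 2, 2)

# Option 1: vmap the 2-D (Frobenius) norm over the leading axis
per_matrix_vmap = jax.vmap(jnp.linalg.norm)(batch)       # shape (3,)

# Option 2: tell norm which two axes form each matrix
per_matrix_axis = jnp.linalg.norm(batch, axis=(-2, -1))  # shape (3,)

# Pitfall: with ord=None and axis=None the 3-D input is flattened,
# giving ONE norm for the whole batch rather than one per matrix
whole_batch = jnp.linalg.norm(batch)                     # shape ()

# Vector ords on a 1-D input
v = jnp.array([3.0, -4.0])
l1 = jnp.linalg.norm(v, ord=1)          # sum of |v_i|
linf = jnp.linalg.norm(v, ord=jnp.inf)  # max |v_i|
```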

Problem

Implement matrix_frobenius(A) that returns the Frobenius norm of a 2-D matrix using jnp.linalg.norm.

  • A: 2-D JAX array of shape (m, n), for any m and n.
  • Returns: scalar (0-D array) equal to sqrt(sum(A²)).
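One possible sketch of matrix_frobenius, assuming the input really is 2-D as the problem states (no shape validation is done here). Passing ord='fro' is a stylistic choice to make the intent explicit; the default ord=None returns the same value for 2-D input.

```python
import jax.numpy as jnp

def matrix_frobenius(A):
    """Frobenius norm of a 2-D array: sqrt(sum(A**2))."""
    # ord='fro' is explicit; ord=None gives the same result for 2-D input
    return jnp.linalg.norm(A, ord='fro')

# Example: 1 + 4 + 4 + 16 = 25, so the norm is 5
A = jnp.array([[1.0, 2.0, 2.0],
               [0.0, 0.0, 4.0]])
result = matrix_frobenius(A)
```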

Hints

jax linalg norm
