Hessian of a Quadratic
Why this matters
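jax.hessian(f) is syntactic sugar for jax.jacfwd(jax.jacrev(f)), the
canonical mixed-mode ("forward-over-reverse") pattern: reverse mode computes
the inner gradient, then forward mode takes the Jacobian of that gradient.
Reverse mode produces the whole gradient of a scalar f in a single backward
pass; forward mode then pushes one tangent per input dimension through the
gradient (batched with vmap), yielding the d × d matrix. JAX's documentation
notes that this composition is typically the most efficient way to build a
dense Hessian.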
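A quick way to see the equivalence is to build the Hessian several ways and compare them on a small test function. The function g below is a hypothetical example, chosen only because its Hessian has an off-diagonal entry; jax.hessian, jax.jacfwd, and jax.jacrev are the standard JAX transforms.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar test function; the x[0] * x[1] term makes H non-diagonal.
    return jnp.sum(x ** 3) + x[0] * x[1]

x = jnp.array([1.0, 2.0, 3.0])

H_sugar = jax.hessian(g)(x)               # built-in convenience transform
H_mixed = jax.jacfwd(jax.jacrev(g))(x)    # forward-over-reverse, written out
H_rev_rev = jax.jacrev(jax.jacrev(g))(x)  # reverse-over-reverse: same values, usually slower

print(H_sugar.shape)  # (3, 3)
print(jnp.allclose(H_sugar, H_mixed), jnp.allclose(H_sugar, H_rev_rev))
```

All three agree numerically; the choice between compositions only affects cost, not the result.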
Hessians are the workhorse of second-order optimization:
- Newton’s method solves H Δx = -g at each step, where g is the gradient.
- Curvature analysis uses the eigenvalues of H to detect saddle points (eigenvalues of mixed sign) and to estimate condition numbers.
- Laplace approximations in Bayesian deep learning build a Gaussian around a MAP estimate using the local Hessian.
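The first two uses can be sketched with jax.grad, jax.hessian, and jnp.linalg. The helper newton_step is hypothetical, and the sketch assumes H is invertible (here, symmetric positive definite); a practical optimizer would add damping or a line search. Because the test objective is an exact quadratic 0.5 * xᵀ A x with positive definite A, one undamped step lands on the minimizer at the origin.

```python
import jax
import jax.numpy as jnp

def newton_step(f, x):
    """Hypothetical helper: one undamped Newton update x + Δx, where H Δx = -g."""
    g = jax.grad(f)(x)             # gradient at x
    H = jax.hessian(f)(x)          # dense d × d Hessian at x
    dx = jnp.linalg.solve(H, -g)   # solve H Δx = -g (assumes H is invertible)
    return x + dx

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])  # symmetric positive definite example matrix

def f(x):
    return 0.5 * x @ A @ x

x0 = jnp.array([2.0, -5.0])
x1 = newton_step(f, x0)
print(x1)  # approximately [0., 0.]

# Curvature check: eigvalsh returns the (ascending) eigenvalues of a symmetric matrix.
# All positive means a local minimum; mixed signs would indicate a saddle point.
eigs = jnp.linalg.eigvalsh(jax.hessian(f)(x0))
print(eigs)
```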
For f(x) = 0.5 * xᵀ A x with symmetric A, the exact Hessian is A —
a clean sanity-check that jax.hessian returns the right answer.
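The claim that the Hessian is A follows from differentiating the quadratic form term by term; the derivation also shows exactly where the symmetry assumption enters.

```latex
\begin{aligned}
f(x) &= \tfrac{1}{2} \sum_{i,j} x_i A_{ij} x_j \\
\frac{\partial f}{\partial x_k} &= \tfrac{1}{2}\Big(\sum_j A_{kj} x_j + \sum_i x_i A_{ik}\Big)
  = \tfrac{1}{2}\big((A + A^\top)\,x\big)_k \\
\frac{\partial^2 f}{\partial x_k \,\partial x_l} &= \tfrac{1}{2}\,(A_{kl} + A_{lk}) \\
\nabla^2 f(x) &= \tfrac{1}{2}\,(A + A^\top) = A \quad \text{if } A = A^\top
\end{aligned}
```

The Hessian is constant in x, so the evaluation point does not matter.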
Worked mini-example
```python
import jax
import jax.numpy as jnp

def f(x):
    return x[0] ** 2 + x[1] ** 2  # = 0.5 * xᵀ [[2, 0], [0, 2]] x

x = jnp.array([3.0, 4.0])
H = jax.hessian(f)(x)
# H → [[2., 0.],
#      [0., 2.]]
# Equals A = [[2, 0], [0, 2]] ✓
```
Common pitfalls
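- The full Hessian takes O(d²) memory, so for large models prefer the Hessian-vector product (HVP) pattern, which computes H v without materializing H.
- jax.hessian is meant for scalar-valued functions. Because it is the same jax.jacfwd(jax.jacrev(f)) composition, a vector-valued f of length m does not raise an error; it returns a stack of per-output Hessians with shape (m, d, d) rather than a single d × d matrix.
- Forgetting the 0.5: without it, f = xᵀ A x has Hessian A + Aᵀ = 2A, not A.
- A must be symmetric for the Hessian to equal A; for a general A, the Hessian of 0.5 * xᵀ A x is 0.5 * (A + Aᵀ). The problem guarantees symmetry, but it is worth keeping in mind.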
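For the memory pitfall, a common alternative is a forward-over-reverse Hessian-vector product: differentiate the gradient function along a direction v with jax.jvp, which returns H v without ever forming H. The helper name hvp is illustrative; the pattern follows the one shown in JAX's autodiff cookbook.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse: push the tangent v through grad(f) at x.
    # jax.jvp returns (primal_out, tangent_out); tangent_out equals H(x) @ v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])  # symmetric example matrix

def f(x):
    return 0.5 * x @ A @ x

x = jnp.ones(3)
v = jnp.array([1.0, 0.0, -1.0])
Hv = hvp(f, x, v)  # O(d) memory: the d × d Hessian is never built
print(Hv)          # equals A @ v for this quadratic
```

Each call costs roughly one gradient evaluation plus one forward pass, which is what makes HVP-based methods (for example conjugate-gradient Newton solvers) practical at scale.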
Problem
Implement hessian_quadratic(x, A) that returns the Hessian of
f(x) = 0.5 * xᵀ A x at x using jax.hessian.
- x: 1-D jax array of length d.
- A: 2-D jax array of shape (d, d), symmetric.
Returns: 2-D array of shape (d, d), which should equal A up to floating-point precision.
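One possible sketch, assuming only what the statement gives: close over A in a scalar-valued quadratic and apply jax.hessian with respect to x. It is not presented as the site's reference solution.

```python
import jax
import jax.numpy as jnp

def hessian_quadratic(x, A):
    # Scalar-valued f(x) = 0.5 * xᵀ A x; A is captured from the enclosing scope.
    def f(x):
        return 0.5 * x @ A @ x
    # jax.hessian differentiates with respect to the first argument (x) by default.
    return jax.hessian(f)(x)

A = jnp.array([[2.0, -1.0],
               [-1.0, 5.0]])  # symmetric, as the problem guarantees
x = jnp.array([0.3, -1.2])
H = hessian_quadratic(x, A)
print(jnp.allclose(H, A))
```

Closing over A keeps f a function of x alone, so jax.hessian needs no argnums argument.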