easy primitives

Hessian of a Quadratic

Why this matters

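jax.hessian(f) is syntactic sugar for jax.jacfwd(jax.jacrev(f)), the canonical mixed-mode pattern: reverse-mode for the inner gradient, then forward-mode for the outer Jacobian. Reverse mode yields the full gradient in a single pass; forward mode then differentiates that gradient along each of the d input directions (batched with vmap). The d × d Hessian therefore costs roughly d Hessian-vector products, not a constant number of passes.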
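
As a quick check of that equivalence, the sketch below compares jax.hessian with the explicit forward-over-reverse composition. The function f and the point x are arbitrary choices made for illustration only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary smooth scalar-valued function, chosen only for illustration.
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.5, -1.0, 2.0])

H_sugar = jax.hessian(f)(x)                 # convenience wrapper
H_explicit = jax.jacfwd(jax.jacrev(f))(x)   # forward-over-reverse, written out

# Both routes should produce the same (3, 3) matrix.
same = bool(jnp.allclose(H_sugar, H_explicit))
print(H_sugar.shape, same)
```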

Hessians are the workhorse of second-order optimization:

  • Newton’s method solves H Δx = -g at each step.
  • Curvature analysis uses the eigenvalues of H to detect saddle points and to estimate condition numbers.
  • Laplace approximations in Bayesian deep learning build a Gaussian around a MAP estimate using the local Hessian.
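
A minimal sketch of the first two uses, assuming a small quadratic objective with a linear term. The matrix A, the vector b, and the starting point x0 are made up for illustration. On a quadratic with positive-definite A, a single Newton step lands on the minimizer.

```python
import jax
import jax.numpy as jnp

# Illustrative data (assumptions, not from the problem statement).
A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])   # symmetric positive definite
b = jnp.array([1.0, -1.0])

def f(x):
    return 0.5 * x @ A @ x - b @ x

x0 = jnp.array([3.0, 4.0])
g = jax.grad(f)(x0)       # gradient at x0
H = jax.hessian(f)(x0)    # Hessian at x0 (equals A for this f)

# Newton step: solve H dx = -g, then update.
dx = jnp.linalg.solve(H, -g)
x1 = x0 + dx              # approximately [0.6, -0.8], the solution of A x = b

# Curvature analysis: eigenvalues of the symmetric Hessian, in ascending order.
eigs = jnp.linalg.eigvalsh(H)
cond = eigs[-1] / eigs[0]  # condition number; all eigenvalues > 0 means no saddle
```

For real problems a linear solve is generally preferable to forming the inverse of H explicitly.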

For f(x) = 0.5 * xᵀ A x with symmetric A, the exact Hessian is A — a clean sanity-check that jax.hessian returns the right answer.
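
The claim can be verified by differentiating twice in index notation. The derivation below makes no assumption on A until the last step, where symmetry is used.

```latex
\begin{align*}
f(x) &= \tfrac{1}{2}\, x^\top A x = \tfrac{1}{2} \sum_{i,j} x_i A_{ij} x_j \\
\frac{\partial f}{\partial x_k} &= \tfrac{1}{2} \sum_{j} A_{kj} x_j + \tfrac{1}{2} \sum_{i} A_{ik} x_i \\
\frac{\partial^2 f}{\partial x_k\, \partial x_l} &= \tfrac{1}{2} \left( A_{kl} + A_{lk} \right) \\
\nabla^2 f(x) &= \tfrac{1}{2} \left( A + A^\top \right) = A \quad \text{when } A = A^\top
\end{align*}
```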

Worked mini-example

import jax, jax.numpy as jnp

def f(x):
    return x[0] ** 2 + x[1] ** 2   # = 0.5 * xᵀ [[2,0],[0,2]] x

x = jnp.array([3.0, 4.0])
H = jax.hessian(f)(x)
# H → [[2., 0.],
#       [0., 2.]]
# Equals A = [[2, 0], [0, 2]] ✓

Common pitfalls

  • Full Hessian is O(d²) memory — for large models prefer the Hessian-vector product (HVP) pattern instead.
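  • jax.hessian is usually applied to scalar-valued functions, where it returns a (d, d) matrix. Because it is just jax.jacfwd(jax.jacrev(f)), it also accepts a vector-valued f with m outputs, but the result then has shape (m, d, d): one Hessian per output component.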
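  • Forgetting 0.5: without it, f = xᵀ A x has Hessian A + Aᵀ, which is 2A for symmetric A, not A.
  • A must be symmetric for the Hessian to equal A. In general the Hessian of 0.5 * xᵀ A x is the symmetric part 0.5 * (A + Aᵀ). The problem guarantees symmetry, but it is worth keeping in mind.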
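
The pitfalls above can be sketched in a few lines. The helper name hvp and the deliberately non-symmetric matrix A are assumptions made for illustration. The HVP uses jax.jvp over jax.grad, so the full Hessian is never materialized.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Hessian-vector product via forward-over-reverse:
    # differentiate grad(f) at x along direction v.
    # (Hypothetical helper name, not a JAX built-in.)
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# Deliberately non-symmetric, to expose the symmetry pitfall.
A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])

def f_half(x):
    return 0.5 * x @ A @ x

def f_no_half(x):
    return x @ A @ x

x = jnp.array([3.0, 4.0])
v = jnp.array([1.0, -1.0])

H_half = jax.hessian(f_half)(x)        # symmetric part: 0.5 * (A + A.T)
H_no_half = jax.hessian(f_no_half)(x)  # A + A.T
Hv = hvp(f_half, x, v)                 # matches H_half @ v
```

For large d, only Hv is affordable; H_half and H_no_half are computed here purely to compare against.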

Problem

Implement hessian_quadratic(x, A) that returns the Hessian of f(x) = 0.5 * xᵀ A x at x using jax.hessian.

  • x: 1-D jax array of length d.
  • A: 2-D jax array (d, d), symmetric.

Returns: 2-D array (d, d) — should equal A (within float precision).
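
One possible shape for such a function is sketched below, assuming the signature given in the problem statement. This is not the page's reference solution, and the test matrix A and point x are illustrative.

```python
import jax
import jax.numpy as jnp

def hessian_quadratic(x, A):
    # Close over A so that f is a function of x alone.
    def f(z):
        return 0.5 * z @ A @ z
    return jax.hessian(f)(x)

# Illustrative check with a small symmetric matrix.
A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
x = jnp.array([3.0, 4.0])
H = hessian_quadratic(x, A)
matches = bool(jnp.allclose(H, A))
```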

Hints

jax hessian
