HVP via grad-of-grad
Why this matters
Computing the full Hessian is O(d²) memory — prohibitive for neural
networks. The Hessian-vector product (HVP) H v requires only O(d)
memory and is the key primitive for:
- Newton’s method: solve H Δx = -g via conjugate gradient using HVPs.
- L-BFGS and K-FAC approximations.
- Influence functions: trace how training points affect predictions.
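As a rough illustration of the Newton's-method use case, the sketch below solves H Δx = -g with `jax.scipy.sparse.linalg.cg`, passing the HVP as a matrix-free operator. The linear term `b`, the helper name `hvp`, and the starting point are assumptions made for the example, not part of the problem.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
b = jnp.array([1.0, -1.0])  # hypothetical linear term, added so the minimum is not at 0

def f(x):
    return 0.5 * x @ A @ x - b @ x

def hvp(x, v):
    # Nested-grad HVP: differentiate <grad f(x), v> with respect to x.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x = jnp.array([1.0, 1.0])
g = jax.grad(f)(x)

# cg only calls the matvec; the (d, d) Hessian is never formed.
dx, _ = cg(lambda v: hvp(x, v), -g)
x_new = x + dx
```

Because f is quadratic here, a single Newton step should land on the minimizer, up to the tolerance of the CG solve.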
The educational baseline uses nested grad: compute
g(x) = ⟨grad(f)(x), v⟩ and then differentiate with respect to x
again. Since grad(f)(x) = A x for our quadratic with symmetric A,
g(x) = (Ax)·v = x·(Av), and its gradient is A v, which is exactly the HVP.
This works, but it runs reverse mode twice (once for the gradient inside
vdot, once for the outer grad). The next problem shows how jvp(grad(f))
achieves the same result with a single mixed-mode pass.
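For orientation, the jvp(grad(f)) variant mentioned above might look like the following minimal sketch. The function name `hvp_jvp_of_grad` is a placeholder chosen here; the same quadratic f is assumed.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])

def f(x):
    return 0.5 * x @ A @ x

def hvp_jvp_of_grad(x, v):
    # jax.jvp returns (primal_out, tangent_out).
    # The tangent of grad(f) at x in direction v is H v.
    _, tangent = jax.jvp(jax.grad(f), (x,), (v,))
    return tangent

x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])
hv = hvp_jvp_of_grad(x, v)
```

Here v enters as the tangent argument of `jax.jvp`, so no inner product with v has to be formed by hand.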
Worked mini-example
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

grad_f = jax.grad(f)
# Differentiate the scalar <grad_f(x), v> with respect to x.
hvp = jax.grad(lambda x: jnp.vdot(grad_f(x), v))(x)
# hvp → [2., 0.]  (A @ v = [[2,0],[0,3]] @ [1,0] = [2,0])
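One way to sanity-check the result on a toy problem is to compare it with the dense Hessian from `jax.hessian`. The sketch below does that for the same A, x, and v. It is meant only as a test for small d, since building H is exactly the O(d²) cost the HVP avoids.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

# Matrix-free HVP via nested grad.
hvp = jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

# Dense reference: a full (d, d) matrix, affordable only for tiny d.
H = jax.hessian(f)(x)
dense = H @ v
```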
Common pitfalls
- grad(grad(f))(x) is not an HVP. jax.grad requires a scalar-valued function, and grad(f) returns a vector, so for vector x the call raises an error. Contract with v first: grad(lambda x: vdot(grad(f)(x), v)).
- jnp.vdot(a, b) is sum(conj(a) * b); for real 1-D arrays it equals jnp.dot(a, b). Either works here.
- Cost: two reverse-mode passes per HVP call. The next problem (jvp-of-grad) does it with one forward and one reverse pass.
- v must be captured from the outer scope: the inner lambda closes over v, which is not differentiated.
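The first two pitfalls can be probed directly. The sketch below assumes current JAX behavior, where `jax.grad` rejects functions whose output is not a scalar by raising a `TypeError`. The flag names are placeholders chosen for the example.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

# grad(f) returns a vector, so the outer grad has no scalar to differentiate.
try:
    jax.grad(jax.grad(f))(x)
    nested_grad_failed = False
except TypeError:
    nested_grad_failed = True

# For real 1-D arrays, vdot and dot compute the same contraction.
same_contraction = bool(jnp.allclose(jnp.vdot(x, v), jnp.dot(x, v)))
```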
Problem
Implement hvp_grad_of_grad(x, A, v) that computes the Hessian-vector
product H @ v for f(x) = 0.5 * xᵀ A x using the nested-grad pattern.
- x: 1-D jax array of length d.
- A: 2-D jax array (d, d), symmetric.
- v: 1-D jax array of length d, the vector to multiply.
Returns: 1-D array of length d — should equal A @ v.
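To see why the symmetry requirement on A matters, here is one possible shape for the function, offered as a sketch rather than the reference solution, followed by a check with a deliberately non-symmetric matrix. The matrices `A_sym` and `A_ns` are made up for the illustration. For a general A the Hessian of 0.5 * xᵀ A x is 0.5 * (A + Aᵀ), which reduces to A only in the symmetric case.

```python
import jax
import jax.numpy as jnp

def hvp_grad_of_grad(x, A, v):
    # One possible nested-grad implementation (a sketch, not the official solution).
    f = lambda y: 0.5 * y @ A @ y
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

A_sym = jnp.array([[2.0, 1.0], [1.0, 3.0]])  # symmetric: result equals A @ v
A_ns = jnp.array([[2.0, 1.0], [0.0, 3.0]])   # non-symmetric: result equals 0.5 * (A + A.T) @ v

out_sym = hvp_grad_of_grad(x, A_sym, v)
out_ns = hvp_grad_of_grad(x, A_ns, v)
```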