
HVP via grad-of-grad

Why this matters

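Materializing the full Hessian takes O(d²) memory, which is prohibitive for neural networks with millions of parameters. The Hessian-vector product (HVP) H v needs only O(d) memory and costs a small constant multiple of a gradient evaluation, which makes it the key primitive for: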

  • Newton’s method: solve H Δx = -g via conjugate gradient using HVPs.
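  • Curvature analysis: power iteration or Lanczos on H to estimate its top eigenvalues. (L-BFGS and K-FAC, by contrast, build their own curvature approximations instead of calling exact HVPs.)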
  • Influence functions: trace how training points affect predictions.
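As a sketch of the first use case, here is a single Newton step where conjugate gradient sees H only through HVPs. The matrix A and starting point x are illustrative assumptions; jax.scipy.sparse.linalg.cg accepts a callable in place of a matrix, so the Hessian is never formed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Illustrative symmetric positive-definite A (assumption: any SPD matrix works).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])

def f(x):
    return 0.5 * x @ A @ x

def hvp(x, v):
    # Nested-grad HVP: differentiate the scalar <grad f(x), v> with respect to x.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x = jnp.array([1.0, -2.0])
g = jax.grad(f)(x)

# Solve H dx = -g; CG only ever calls the map u -> H u.
delta, _ = cg(lambda u: hvp(x, u), -g)
x_new = x + delta
print(x_new)  # should be close to [0, 0], the minimizer of this quadratic
```

For a quadratic, one exact Newton step lands on the minimizer, so x_new is a quick sanity check on the HVP.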

The educational baseline uses nested grad: form the scalar g(x) = ⟨grad(f)(x), v⟩ and differentiate it with respect to x again. Because A is symmetric, grad(f)(x) = A x for our quadratic, so g(x) = (Ax)·v = x·(Av), whose gradient is A v: exactly the HVP.
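The symmetry assumption matters. For a general A, grad(f)(x) = ½(A + Aᵀ)x, so nested grad returns ½(A + Aᵀ)v, which is the true Hessian-vector product but not A @ v. A small check with a deliberately non-symmetric A (an illustrative choice):

```python
import jax
import jax.numpy as jnp

# Deliberately non-symmetric A (illustrative assumption).
A = jnp.array([[1.0, 2.0], [0.0, 3.0]])
x = jnp.array([0.5, -1.0])
v = jnp.array([1.0, 1.0])

def f(x):
    return 0.5 * x @ A @ x

grad_f = jax.grad(f)
hvp = jax.grad(lambda y: jnp.vdot(grad_f(y), v))(x)

sym = 0.5 * (A + A.T)  # the actual Hessian of f
print(hvp)    # [2. 4.]  == sym @ v
print(A @ v)  # [3. 3.]  differs because A is not symmetric
```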

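This works, but it is reverse-over-reverse: the outer grad must record the entire inner backward pass (the grad inside vdot) and then differentiate back through it, which costs extra memory and work. The next problem shows how jvp(grad(f)) achieves the same result with forward-over-reverse mode.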
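A side-by-side sketch of the two modes on the worked example's A, x and v. jax.jvp takes a tuple of primals and a tuple of tangents and returns (primal_out, tangent_out); pushing v through grad(f) makes the tangent output equal to H v.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

# Reverse-over-reverse: grad of the scalar <grad f(x), v>.
hvp_rr = jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

# Forward-over-reverse: push the tangent v through grad(f).
_, hvp_fr = jax.jvp(jax.grad(f), (x,), (v,))

print(hvp_rr, hvp_fr)  # both [2. 0.]
```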

Worked mini-example

import jax, jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

grad_f = jax.grad(f)
hvp = jax.grad(lambda x: jnp.vdot(grad_f(x), v))(x)
# hvp → [2., 0.]   (A @ v = [[2,0],[0,3]] @ [1,0] = [2,0])
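The same pattern works for any scalar function, not just quadratics. A sketch with a hypothetical helper hvp_nested and an illustrative non-quadratic g, checked against the dense jax.hessian (acceptable only because d is tiny here):

```python
import jax
import jax.numpy as jnp

def hvp_nested(f, x, v):
    """Hypothetical helper: HVP of scalar-valued f at x along v via nested grad."""
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

# Non-quadratic test function (illustrative choice).
def g(x):
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.3, -1.2, 2.0])
v = jnp.array([1.0, 2.0, -0.5])

approx = hvp_nested(g, x, v)
exact = jax.hessian(g)(x) @ v  # dense O(d^2) Hessian, for checking only
print(jnp.allclose(approx, exact, atol=1e-4))
```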

Common pitfalls

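  • grad(grad(f))(x) does not compute the HVP at all: for vector x it raises a TypeError, because jax.grad requires a scalar-valued function and grad(f) returns a vector. Contracting with v first, as in grad(⟨grad(f)(x), v⟩), produces the scalar the outer grad needs. If you really want the dense Hessian, use jax.hessian(f)(x).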
  • jnp.vdot(a, b) is sum(conj(a) * b) — for real arrays it equals jnp.dot(a, b). Either works here.
  • Cost: two reverse-mode passes per HVP call. The next problem (jvp-of-grad) does it with one forward + one reverse pass.
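  • v is treated as a constant: the inner lambda closes over v, so only x is differentiated. Passing v as a second argument also works, since jax.grad differentiates only with respect to the first argument (argnums=0) by default.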
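The first and last pitfalls can be checked directly. This sketch reuses the worked example's A, x and v; the two-argument function inner is an illustrative name.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

# grad requires a scalar output, but grad(f) returns a vector of shape (2,).
failed = False
try:
    jax.grad(jax.grad(f))(x)
except TypeError as err:
    failed = True
    print("grad(grad(f)) failed:", err)

# Dense Hessian, if really needed (O(d^2) memory).
H = jax.hessian(f)(x)

# v passed explicitly; jax.grad differentiates only the first argument by default.
def inner(y, v):
    return jnp.vdot(jax.grad(f)(y), v)

hvp = jax.grad(inner)(x, v)
print(H @ v, hvp)  # both [2. 0.]
```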

Problem

Implement hvp_grad_of_grad(x, A, v) that computes the Hessian-vector product H @ v for f(x) = 0.5 * xᵀ A x using the nested-grad pattern.

  • x: 1-D jax array of length d.
  • A: 2-D jax array (d, d), symmetric.
  • v: 1-D jax array of length d — the vector to multiply.

Returns: 1-D array of length d — should equal A @ v.

