HVP via jvp-of-grad
Why this matters
jvp(grad(f), (x,), (v,)) is the production HVP pattern. It computes
H @ v with one reverse-mode pass (to build grad(f)) and one
forward-mode pass (to propagate the tangent v through grad(f)).
This forward-over-reverse composition is typically cheaper than the nested-grad
(reverse-over-reverse) baseline from the previous problem, which runs a second
reverse pass over the entire gradient computation and must store its intermediates.
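A minimal sketch of the pattern for an arbitrary scalar function, using only the public jax.jvp, jax.grad and jax.hessian APIs. The helper name hvp and the test function f are illustrative assumptions. The dense jax.hessian result is there only as a reference check: it builds the full d x d matrix, which the jvp-of-grad version never does.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    """H(x) @ v for a scalar-valued f, via forward-over-reverse (jvp of grad)."""
    # jax.jvp returns (grad_f(x), H(x) @ v); keep only the tangent.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def f(x):
    # Any smooth scalar function works; this one is deliberately non-quadratic,
    # so its Hessian depends on x.
    return jnp.sum(jnp.exp(x) * x) + jnp.prod(x)

x = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 2.0, -1.0])

fast = hvp(f, x, v)                 # matrix-free: no d x d array is built
dense = jax.hessian(f)(x) @ v       # reference: materializes the full Hessian
```

Because hvp takes f as an argument, the same helper works for any differentiable scalar loss, not just the quadratic used in the problem below.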
The math: the Jacobian-vector product of grad(f) in direction v is
exactly d/dt grad(f)(x + t v)|_{t=0} = H(x) v. Forward-mode evaluates
this directional derivative in one pass, so the total cost is one reverse
plus one forward.
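The directional-derivative identity can be checked numerically. This sketch assumes a hypothetical non-quadratic test function whose Hessian has the closed form H(x) = A - diag(sin x) for symmetric A, and compares the jvp-of-grad result with that closed form and with a central finite difference of grad(f) along v (the step size eps is an arbitrary choice).

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 1.0]])  # symmetric

def f(x):
    # Non-quadratic, so H depends on x: H(x) = A - diag(sin(x)).
    return jnp.sum(jnp.sin(x)) + 0.5 * x @ A @ x

grad_f = jax.grad(f)
x = jnp.array([0.3, -1.2, 2.0])
v = jnp.array([1.0, 0.5, -2.0])

# Forward mode: exact derivative d/dt grad_f(x + t v) at t = 0, i.e. H(x) v.
_, hvp = jax.jvp(grad_f, (x,), (v,))

# Central finite difference of grad_f along v approximates the same quantity.
eps = 1e-2
hvp_fd = (grad_f(x + eps * v) - grad_f(x - eps * v)) / (2 * eps)

# Closed form for this particular f.
hvp_exact = A @ v - jnp.sin(x) * v
```

The finite difference needs two gradient evaluations and carries truncation and rounding error, while the jvp is exact up to floating-point rounding.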
This pattern powers:
- Conjugate gradient Newton (CG-Newton) solvers that never form H.
- Gauss-Newton and Fisher-information approximations.
- Per-example gradient norms via batched HVPs.
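The first bullet can be sketched as a single matrix-free Newton step. jax.scipy.sparse.linalg.cg accepts a callable in place of a matrix, so H is only ever touched through jvp-of-grad products. This is a minimal sketch on an assumed strictly convex quadratic; newton_cg_step is a hypothetical helper name, and a practical solver would add damping, a line search, and a stopping rule.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])  # SPD
b = jnp.array([1.0, 2.0, 3.0])

def f(x):
    # Strictly convex quadratic: its minimizer solves A x = b.
    return 0.5 * x @ A @ x - b @ x

def newton_cg_step(f, x):
    g = jax.grad(f)(x)

    # Matrix-free linear operator: each call is one jvp-of-grad HVP.
    def hvp(v):
        return jax.jvp(jax.grad(f), (x,), (v,))[1]

    # Solve H p = -g with conjugate gradient using only HVPs; cg returns (p, info).
    p, _ = cg(hvp, -g)
    return x + p

x0 = jnp.zeros(3)
x1 = newton_cg_step(f, x0)
```

For a quadratic, one exact Newton step lands on the minimizer, so x1 can be checked against jnp.linalg.solve(A, b).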
Worked mini-example
```python
import jax, jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

grad_f = jax.grad(f)
_, hvp = jax.jvp(grad_f, (x,), (v,))
# hvp → [2., 0.]   (A @ v = [[2,0],[0,3]] @ [1,0] = [2,0])
# primal (discarded) → grad_f(x) = A @ x = [2., 3.]
```
Common pitfalls
- Tuple wrapping is required: jax.jvp(grad_f, x, v) raises a TypeError; both primals and tangents must be tuples (or lists): (x,) and (v,).
- Take the second element: jax.jvp returns (primal, tangent). We want the tangent (index 1), not the primal gradient.
- Same answer, different cost: this problem and jax-hvp-via-grad-of-grad produce the same outputs (up to floating-point rounding); the lesson is the computational graph, not the math.
- jvp computes the JVP of its first argument: here the first argument is grad_f, so the tangent is H @ v.
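The first two pitfalls are easy to check directly. This sketch reuses the mini-example's A, x and v, and assumes the behavior of recent JAX versions, which reject non-sequence primals and tangents with a TypeError.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])
grad_f = jax.grad(lambda y: 0.5 * y @ A @ y)

# Pitfall 1: passing bare arrays instead of tuples is rejected.
try:
    jax.jvp(grad_f, x, v)
    raised = False
except TypeError:
    raised = True

# Pitfall 2: jvp returns (primal_out, tangent_out); the HVP is the tangent.
primal_out, tangent_out = jax.jvp(grad_f, (x,), (v,))
# primal_out is grad_f(x) = A @ x; tangent_out is H @ v = A @ v.
```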
Problem
Implement hvp_jvp_of_grad(x, A, v) that computes H @ v for
f(x) = 0.5 * xᵀ A x using jax.jvp(grad(f)). Because A is symmetric, the Hessian of f is H = A.
- x: 1-D jax array of length d.
- A: 2-D jax array of shape (d, d), symmetric.
- v: 1-D jax array of length d.
Returns: 1-D array of length d, equal to A @ v.
The result matches the grad-of-grad baseline (up to floating-point rounding).
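One way to implement and check the spec is sketched below. The baseline hvp_grad_of_grad is a hypothetical reconstruction of the reverse-over-reverse approach (the gradient of the scalar grad(f)(x) · v), not necessarily the previous problem's exact code; quadratic is likewise an illustrative helper.

```python
import jax
import jax.numpy as jnp

def quadratic(x, A):
    return 0.5 * x @ A @ x

def hvp_jvp_of_grad(x, A, v):
    # Forward-over-reverse: push tangent v through grad(f) and keep the tangent.
    grad_f = lambda y: jax.grad(quadratic)(y, A)
    return jax.jvp(grad_f, (x,), (v,))[1]

def hvp_grad_of_grad(x, A, v):
    # Reverse-over-reverse baseline (assumed form): gradient of <grad f(y), v>.
    grad_f = lambda y: jax.grad(quadratic)(y, A)
    return jax.grad(lambda y: jnp.vdot(grad_f(y), v))(x)

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
x = jnp.array([1.0, -1.0])
v = jnp.array([0.5, 2.0])

out_fwd = hvp_jvp_of_grad(x, A, v)
out_rev = hvp_grad_of_grad(x, A, v)
# Both should equal A @ v = [3.0, 6.5] for this symmetric A.
```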