
HVP via jvp-of-grad

Why this matters

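jvp(grad(f), (x,), (v,)) is the production HVP pattern, often called forward-over-reverse. It computes H @ v with one reverse-mode pass (to build grad(f)) and one forward-mode pass (to propagate the tangent v through grad(f)), without ever forming H. It is typically cheaper, particularly in memory, than the reverse-over-reverse (grad-of-grad) baseline from the previous problem. In that baseline, the outer reverse pass must store the intermediates of the inner gradient computation so that it can sweep backward through them.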
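A minimal sketch contrasting the two compositions. It assumes a test function f(x) = sum(sin(x)), chosen only because its Hessian, diag(-sin(x)), is easy to check by hand. The helper names hvp_fwd_over_rev and hvp_rev_over_rev are hypothetical.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical test function: gradient cos(x), Hessian diag(-sin(x)).
    return jnp.sum(jnp.sin(x))

def hvp_fwd_over_rev(f, x, v):
    # Forward-over-reverse: push the tangent v through grad(f) with jvp.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hvp_rev_over_rev(f, x, v):
    # Reverse-over-reverse (grad-of-grad baseline): differentiate <grad f(x), v>.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

x = jnp.array([0.1, 0.7, 1.3])
v = jnp.array([1.0, -2.0, 0.5])
print(hvp_fwd_over_rev(f, x, v))  # approximately -sin(x) * v
print(hvp_rev_over_rev(f, x, v))  # same values, different computational graph
```

Both functions return the same vector. Only the structure of the traced computation differs.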

The math: the Jacobian-vector product of grad(f) in direction v is exactly d/dt grad(f)(x + t v)|_{t=0} = H(x) v. Forward-mode evaluates this directional derivative in one pass, so the total cost is one reverse plus one forward.
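Written out for the quadratic used in the worked example below. This is a short derivation, and the symmetrization step is standard calculus rather than something stated in the text:

```latex
\begin{aligned}
\frac{d}{dt}\,\nabla f(x + t v)\Big|_{t=0} &= \nabla^2 f(x)\, v = H(x)\, v, \\
f(x) = \tfrac12\, x^\top A x \;\Rightarrow\; \nabla f(x) &= \tfrac12 (A + A^\top)\, x,
  \qquad H = \tfrac12 (A + A^\top), \\
A = A^\top \;\Rightarrow\; H v &= A v.
\end{aligned}
```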

This pattern powers:

  • Conjugate gradient Newton (CG-Newton) solvers that never form H.
  • Gauss-Newton and Fisher-information approximations.
  • Gradient-norm penalties, whose gradient ∇(½‖∇f(x)‖²) = H(x) ∇f(x) is itself an HVP.
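A minimal CG-Newton step, sketched under assumptions that are not in the text above. The objective is a quadratic f(x) = 0.5 xᵀ A x - bᵀ x with a symmetric positive-definite A. The linear solver is jax.scipy.sparse.linalg.cg. The HVP closure is passed to cg as the matrix-vector product, so H is never formed. The helper name newton_cg_step is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # assumed symmetric positive-definite
b = jnp.array([1.0, 2.0])

def f(x):
    return 0.5 * x @ A @ x - b @ x

def newton_cg_step(f, x):
    g = jax.grad(f)(x)
    # Matrix-free Hessian: each call is one jvp-of-grad HVP at the current x.
    hvp = lambda v: jax.jvp(jax.grad(f), (x,), (v,))[1]
    # Solve H p = -g with conjugate gradient; cg accepts a function as the operator.
    p, _ = cg(hvp, -g)
    return x + p

x0 = jnp.zeros(2)
x1 = newton_cg_step(f, x0)
print(x1)  # for a quadratic, one Newton step lands on the minimizer A^{-1} b
```

On a non-quadratic objective, the same step would be repeated, and usually damped or combined with a line search.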

Worked mini-example

import jax, jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])

def f(x):
    return 0.5 * x @ A @ x

grad_f = jax.grad(f)                    # x -> A @ x (A is symmetric)
_, hvp = jax.jvp(grad_f, (x,), (v,))    # tangent output is H @ v
# hvp → [2., 0.]   (A @ v = [[2,0],[0,3]] @ [1,0] = [2,0])
# primal (discarded) → grad_f(x) = A @ x = [2., 3.]
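Continuing the worked example, here is a sketch of batching HVPs with jax.vmap. Mapping the HVP over the rows of an identity matrix recovers the full Hessian one basis vector at a time, and for this f the result should equal A. This approach only makes sense for small d, because the point of the pattern is usually to avoid forming H. The helper hvp_at is a hypothetical name.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])

def f(x):
    return 0.5 * x @ A @ x

def hvp_at(v):
    # One forward-over-reverse HVP at the fixed point x.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# Row i of the output is H @ e_i; because H is symmetric, the stacked result equals H.
H = jax.vmap(hvp_at)(jnp.eye(2))
print(H)  # equals A for this quadratic
```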

Common pitfalls

  • Tuple wrapping is required: jax.jvp(grad_f, x, v) raises a TypeError. Primals and tangents must each be passed as a tuple (or list), here (x,) and (v,).
  • Take the second element: jax.jvp returns (primal, tangent). We want the tangent (index 1), not the primal gradient.
  • Same answer, different cost: this problem and jax-hvp-via-grad-of-grad produce the same values (up to floating-point rounding). The lesson is the computational graph, not the math.
  • jvp differentiates its first argument: here that argument is grad_f, so the tangent output is H @ v.
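A small sketch that exercises the first two pitfalls on the worked example's A, x, and v. Catching TypeError reflects the type check jax.jvp performs on its primals and tangents arguments.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.0], [0.0, 3.0]])
x = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 0.0])
grad_f = jax.grad(lambda y: 0.5 * y @ A @ y)

# Pitfall 1: bare arrays instead of tuples are rejected.
try:
    jax.jvp(grad_f, x, v)
    tuple_error = False
except TypeError:
    tuple_error = True

# Pitfall 2: the output is (primal, tangent); the HVP is the tangent.
primal, tangent = jax.jvp(grad_f, (x,), (v,))
print(tuple_error)  # True
print(primal)       # grad_f(x) = A @ x
print(tangent)      # H @ v = A @ v
```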

Problem

Implement hvp_jvp_of_grad(x, A, v) that computes H @ v for f(x) = 0.5 * xᵀ A x by applying jax.jvp to jax.grad(f).

  • x: 1-D jax array of length d.
  • A: 2-D jax array (d, d), symmetric.
  • v: 1-D jax array of length d.

Returns: 1-D array of length d, equal to A @ v.

The result should match the grad-of-grad baseline up to floating-point rounding.
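One possible sketch of hvp_jvp_of_grad, shown for illustration with arbitrary test values. jax.grad differentiates the expression 0.5 * y @ A @ y exactly as written, so the HVP it returns is ½(A + Aᵀ) v. That equals A @ v only because the problem guarantees A is symmetric.

```python
import jax
import jax.numpy as jnp

def hvp_jvp_of_grad(x, A, v):
    # f closes over A; for symmetric A its Hessian is A everywhere.
    def f(y):
        return 0.5 * y @ A @ y
    # Forward-over-reverse: the tangent output of jvp(grad(f)) is H @ v.
    _, hv = jax.jvp(jax.grad(f), (x,), (v,))
    return hv

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])  # symmetric, as the problem requires
x = jnp.array([0.5, -1.0])
v = jnp.array([1.0, 2.0])
print(hvp_jvp_of_grad(x, A, v))  # A @ v = [4, 7]
```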
