Difficulty: hard | Topic: primitives
Higher-Order custom_vjp
Why this matters
A custom_vjp rule defines the backward pass of a function. But what happens when you differentiate through that backward pass itself? JAX treats the forward and backward rules like any other JAX code and differentiates through them, so higher-order differentiation works only if the rules themselves are JAX-traceable.
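To make "differentiating through the rule" concrete, here is a minimal scalar sketch. The names sin_custom, sin_fwd and sin_bwd are illustrative and not from the original. The forward rule saves cos(x) as a JAX residual, so the second grad differentiates the backward rule's output cos(x) and should recover -sin(x).

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def sin_custom(x):
    return jnp.sin(x)


def sin_fwd(x):
    # Residual cos(x) is computed with jnp, so it still depends on x when traced.
    return jnp.sin(x), jnp.cos(x)


def sin_bwd(cos_x, g):
    # d/dx sin(x) = cos(x), scaled by the incoming cotangent g.
    return (cos_x * g,)


sin_custom.defvjp(sin_fwd, sin_bwd)

x = 3.0
first = jax.grad(sin_custom)(x)             # cos(3.0)
second = jax.grad(jax.grad(sin_custom))(x)  # -sin(3.0): derivative of cos_x * 1.0
```

With a scalar input, jax.grad(sin_custom) already returns a scalar, so grad(grad(...)) can be applied directly; the vector case below needs an extra reduction.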
This pattern is critical for:
- Meta-learning (MAML, higher-order MAML): differentiating through an inner gradient step.
- Bi-level optimization: the outer objective differentiates through the inner optimizer's gradient.
- Physics-informed networks: loss terms involve PDE residuals obtained by differentiating functions that carry custom numerical rules.
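A minimal sketch of the meta-learning case: one inner gradient step through a custom_vjp function, then an outer gradient through that step. The inner step size INNER_LR and the helper names are hypothetical. For sum(x²) the adapted parameters are (1 - 2·lr)·params, so the meta-gradient has the closed form 2·(1 - 2·lr)²·params, which the sketch can be checked against.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def sumsq(x):
    return jnp.sum(x ** 2)


def sumsq_fwd(x):
    return sumsq(x), x  # residual: the input itself


def sumsq_bwd(x, g):
    return (g * 2 * x,)  # d/dx sum(x**2) = 2x


sumsq.defvjp(sumsq_fwd, sumsq_bwd)

INNER_LR = 0.1  # hypothetical inner-loop step size


def outer_loss(params):
    # Inner step: one gradient-descent update that goes through sumsq_bwd.
    adapted = params - INNER_LR * jax.grad(sumsq)(params)
    # Outer objective evaluated at the adapted parameters.
    return sumsq(adapted)


params = jnp.array([1.0, 2.0])
# The outer grad differentiates through jax.grad(sumsq), i.e. through sumsq_bwd.
meta_grad = jax.grad(outer_loss)(params)  # 2 * (1 - 2*0.1)**2 * params = [1.28, 2.56]
```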
Worked mini-example
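import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sum(x ** 2)  # f(x) = sum(x²)

def fwd(x):
    return f(x), x  # primal output, residuals = x

def bwd(x, g):
    # x: residual saved by fwd; g: cotangent of the scalar output
    return (g * 2 * x,)  # grad = 2x, one entry per input of f

f.defvjp(fwd, bwd)

x = jnp.array([1.0, 2.0])
g = jax.grad(f)(x)  # → [2., 4.] (= 2x)
# jax.grad needs a scalar output, so sum the gradient before differentiating again
gg = jax.grad(lambda x: jnp.sum(jax.grad(f)(x)))(x)
# → [2., 2.] (d/dx of 2x = 2, constant)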
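Summing the inner gradient and differentiating again computes H·1, the row sums of the Hessian H. For sum(x²) the Hessian is 2I, so the row sums coincide with the diagonal, which is why every entry is 2. As a cross-check, this sketch materializes the full Hessian with jax.jacrev(jax.grad(f)), keeping both differentiation levels in reverse mode, the mode custom_vjp is built for.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sum(x ** 2)


def fwd(x):
    return f(x), x


def bwd(x, g):
    return (g * 2 * x,)


f.defvjp(fwd, bwd)

x = jnp.array([1.0, 2.0])

# Full Hessian, reverse-over-reverse: jacrev of the vector-valued grad(f).
hess = jax.jacrev(jax.grad(f))(x)  # 2 * identity, shape (2, 2)

# The sum-wrapper trick computes H @ ones, i.e. the row sums of the Hessian.
row_sums = jax.grad(lambda x: jnp.sum(jax.grad(f)(x)))(x)  # [2., 2.]
```

For a function whose Hessian has off-diagonal terms, the sum-wrapper result and the Hessian diagonal would differ, so jacrev(grad(...)) is the safer choice when the full second-order structure matters.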
Common pitfalls
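- Non-traceable ops in the rule break higher-order grad. Python if/for on traced values, .item(), or NumPy calls inside _bwd need concrete values; at the second differentiation level the residuals arrive as tracers, so these calls can fail there even when first-order grad works.
- The jnp.sum wrapper: jax.grad(f)(x) returns a vector, but jax.grad needs a scalar output. Wrap the inner gradient as jnp.sum(jax.grad(f)(x)) before differentiating again.
- Residuals must be JAX arrays. If _fwd returns a residual as a Python scalar (for example via float(x) or .item()), the residual loses its dependence on the input, and higher-order tracing either raises a tracer error or silently drops that dependence.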
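A sketch of a traceable branch inside a backward rule: jnp.where replaces a Python if on traced values, and the residual stays a JAX array. The function pos_sumsq (sum of squares over the positive entries) is a hypothetical example, not from the original.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def pos_sumsq(x):
    # Sum of x**2 over the positive entries only.
    return jnp.sum(jnp.where(x > 0, x ** 2, 0.0))


def pos_sumsq_fwd(x):
    return pos_sumsq(x), x  # residual is a JAX array, not a Python scalar


def pos_sumsq_bwd(x, g):
    # Branch with jnp.where instead of a Python `if` on traced values.
    return (g * jnp.where(x > 0, 2 * x, 0.0),)


pos_sumsq.defvjp(pos_sumsq_fwd, pos_sumsq_bwd)

x = jnp.array([-1.0, 2.0])
grad1 = jax.grad(pos_sumsq)(x)                                  # [0., 4.]
grad2 = jax.grad(lambda x: jnp.sum(jax.grad(pos_sumsq)(x)))(x)  # [0., 2.]
```

Because jnp.where is an ordinary JAX operation, the outer grad differentiates through the selected branch (2x becomes 2) and gives 0 where the other branch was taken.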
Problem
Implement custom_vjp_grad_of_grad(x) that:
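- Defines a function custom_sumsq(x) = sum(x²) with @jax.custom_vjp.
- Registers _fwd (residual = x) and _bwd (returns cotangent * 2 * x).
- Computes the second-order gradient of custom_sumsq. Because jax.grad(custom_sumsq)(x) is a vector, a literal grad(grad(custom_sumsq))(x) fails for a 1-D x; the outer grad must differentiate the summed gradient, jax.grad(lambda x: jnp.sum(jax.grad(custom_sumsq)(x)))(x).
- Input: x, a 1-D JAX array.
Returns: a 1-D array of the same shape, filled with 2.0 (the second derivative of x² is 2).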
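One possible sketch of the requested function, following the worked example above; it is not the site's hidden reference solution. It assumes x has a floating-point dtype, since jax.grad rejects integer inputs.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def custom_sumsq(x):
    return jnp.sum(x ** 2)


def _fwd(x):
    return custom_sumsq(x), x  # residual = x


def _bwd(x, cotangent):
    return (cotangent * 2 * x,)  # VJP of sum(x**2)


custom_sumsq.defvjp(_fwd, _bwd)


def custom_vjp_grad_of_grad(x):
    # The inner grad is a vector, so sum it to give the outer grad a scalar.
    return jax.grad(lambda y: jnp.sum(jax.grad(custom_sumsq)(y)))(x)


out = custom_vjp_grad_of_grad(jnp.array([1.0, -3.0, 0.5]))  # [2., 2., 2.]
```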
Tags: jax, custom-vjp, higher-order