
Higher-Order custom_vjp

Why this matters

A custom_vjp rule defines the backward pass of a function. But what happens when you differentiate through that backward pass itself? JAX treats the rule like any other JAX code: the residuals computed in the forward rule and the operations in the backward rule are traced and differentiated with ordinary autodiff. Higher-order differentiation therefore works only if both rules are JAX-traceable.
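A minimal sketch of that mechanism, using a hypothetical custom_vjp version of sin (my_sin, my_sin_fwd and my_sin_bwd are illustrative names). The rule only encodes the first derivative cos(x), yet the second derivative still comes out right, because the outer jax.grad differentiates the residual jnp.cos(x) that the forward rule computed.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: a custom_vjp version of sin whose residual is cos(x).
@jax.custom_vjp
def my_sin(x):
    return jnp.sin(x)

def my_sin_fwd(x):
    # cos(x) is computed with JAX ops, so an outer differentiation
    # can see how this residual depends on x.
    return jnp.sin(x), jnp.cos(x)

def my_sin_bwd(cos_x, g):
    # First-order rule only: d/dx sin(x) = cos(x).
    return (g * cos_x,)

my_sin.defvjp(my_sin_fwd, my_sin_bwd)

x = 0.5
d1 = jax.grad(my_sin)(x)            # should equal cos(0.5)
d2 = jax.grad(jax.grad(my_sin))(x)  # differentiates the residual: -sin(0.5)
print(d1, jnp.cos(x))
print(d2, -jnp.sin(x))
```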

This pattern is critical for:

  • Meta-learning (MAML and its variants): the outer update differentiates through an inner gradient step.
  • Bi-level optimization: the outer objective differentiates through the inner optimizer's gradient.
  • Physics-informed networks: the loss contains PDE residuals built from derivatives of the model, so training takes gradients of gradients, including through any custom derivative rules those derivatives use.
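A minimal MAML-style sketch of the first use case, assuming a single inner gradient-descent step with an arbitrarily chosen learning rate lr=0.1. The names loss and outer_objective are hypothetical; loss is a custom_vjp sum of squares whose rule returns 2w.

```python
import jax
import jax.numpy as jnp

# Custom rule: loss(w) = sum(w**2), gradient 2w.
@jax.custom_vjp
def loss(w):
    return jnp.sum(w ** 2)

def loss_fwd(w):
    return jnp.sum(w ** 2), w        # residual: w itself

def loss_bwd(w, g):
    return (g * 2 * w,)

loss.defvjp(loss_fwd, loss_bwd)

def outer_objective(w, lr=0.1):
    # Inner step: one gradient-descent update that goes through the custom rule.
    w_adapted = w - lr * jax.grad(loss)(w)
    # Outer loss evaluated at the adapted parameters.
    return loss(w_adapted)

w = jnp.array([1.0, 2.0])
meta_grad = jax.grad(outer_objective)(w)
# Analytically w_adapted = (1 - 2*lr) * w, so the meta-gradient is 2 * (1 - 2*lr)**2 * w.
print(meta_grad, 2 * (1 - 2 * 0.1) ** 2 * w)
```

The meta-gradient differentiates jax.grad(loss), and hence the custom backward rule, a second time; that term is what a first-order MAML approximation would drop.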

Worked mini-example

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sum(x ** 2)      # f(x) = sum(x²)

def fwd(x):
    return f(x), x              # residuals = x

def bwd(x, g):
    return (g * 2 * x,)         # grad = 2x

f.defvjp(fwd, bwd)

x = jnp.array([1.0, 2.0])
g = jax.grad(f)(x)              # → [2., 4.]  (= 2x)
gg = jax.grad(lambda x: jnp.sum(jax.grad(f)(x)))(x)
# → [2., 2.]  (d/dx of 2x = 2, constant)
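A sketch that builds the full Hessian of the same f by nesting jax.jacrev, so both differentiation levels are reverse mode and both go through the custom rule; for sum(x²) it should be 2·I. It also illustrates that the jnp.sum trick computes H·1, the Hessian applied to a vector of ones, which equals the diagonal here only because this Hessian is diagonal.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sum(x ** 2)

def fwd(x):
    return f(x), x

def bwd(x, g):
    return (g * 2 * x,)

f.defvjp(fwd, bwd)

x = jnp.array([1.0, 2.0])

# Full Hessian via reverse-over-reverse: should be 2 * identity.
H = jax.jacrev(jax.jacrev(f))(x)

# grad of sum(grad f) is the Hessian applied to a vector of ones.
gg = jax.grad(lambda x: jnp.sum(jax.grad(f)(x)))(x)
print(H)
print(gg, H @ jnp.ones_like(x))
```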

Common pitfalls

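  • Non-traceable ops in the rule break higher-order grad: at the second differentiation level, _bwd runs on tracers rather than concrete arrays. NumPy calls (np.*) on those tracers raise an error, and Python if on array values or .item() can fail once the computation is traced abstractly (for example under jax.jit). Use jnp functions and jnp.where or jax.lax.cond instead.
  • The jnp.sum wrapper: jax.grad(f)(x) returns a vector, but jax.grad only differentiates scalar-output functions. Differentiate jnp.sum(jax.grad(f)(x)) instead, which yields the Hessian applied to a vector of ones.
  • Residuals must carry the dependence on x: in the example above, the second-order gradient reaches x only through the residual that _fwd saves. Compute residuals as JAX arrays with JAX ops; converting one to a Python scalar or NumPy array either fails to trace or silently cuts that dependence.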
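A sketch of the first and last pitfalls, using two hypothetical functions. signed_sq (the sum of x·|x|) branches with jnp.where instead of a Python if, so its second-order gradient also works under jax.jit. cut_sumsq wraps its residual in jax.lax.stop_gradient, standing in for any residual that no longer depends on x through JAX ops (for example one frozen into a NumPy or Python constant); its second-order gradient silently comes out as zeros instead of 2.

```python
import jax
import jax.numpy as jnp

# Hypothetical piecewise function: sum(x * |x|), branching with jnp.where.
@jax.custom_vjp
def signed_sq(x):
    return jnp.sum(jnp.where(x > 0, x ** 2, -x ** 2))

def signed_sq_fwd(x):
    return signed_sq(x), x                      # residual: x, a JAX array

def signed_sq_bwd(x, g):
    # jnp.where instead of `if x > 0`, which fails on abstract tracers.
    return (g * jnp.where(x > 0, 2 * x, -2 * x),)

signed_sq.defvjp(signed_sq_fwd, signed_sq_bwd)

second = jax.jit(jax.grad(lambda x: jnp.sum(jax.grad(signed_sq)(x))))
x = jnp.array([-1.0, 2.0])
print(second(x))    # -2 where x < 0, 2 where x > 0

# Counter-example: the residual's dependence on x is cut.
@jax.custom_vjp
def cut_sumsq(x):
    return jnp.sum(x ** 2)

def cut_fwd(x):
    # stop_gradient stands in for a residual frozen into a constant.
    return jnp.sum(x ** 2), jax.lax.stop_gradient(x)

def cut_bwd(x, g):
    return (g * 2 * x,)

cut_sumsq.defvjp(cut_fwd, cut_bwd)

first_cut = jax.grad(cut_sumsq)(x)                                  # still correct: 2x
second_cut = jax.grad(lambda x: jnp.sum(jax.grad(cut_sumsq)(x)))(x)  # zeros: wrong
print(first_cut, second_cut)
```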

Problem

Implement custom_vjp_grad_of_grad(x) that:

  1. Defines a function custom_sumsq(x) = sum(x²) with @jax.custom_vjp.
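  2. Registers _fwd (saving x as the residual) and _bwd (returning the 1-tuple (cotangent * 2 * x,)) via custom_sumsq.defvjp(_fwd, _bwd).
  3. Computes the second-order gradient as jax.grad(lambda x: jnp.sum(jax.grad(custom_sumsq)(x)))(x); a bare grad(grad(custom_sumsq)) fails because the inner gradient is a vector, not a scalar.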
  • x: a 1-D floating-point JAX array.

Returns: a 1-D array with the same shape as x, filled with 2.0 (the second derivative of x² is 2).
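One possible sketch following the steps above; it is illustrative, not an official solution. Defining custom_sumsq inside the function keeps the example self-contained; a module-level definition would work equally well.

```python
import jax
import jax.numpy as jnp

def custom_vjp_grad_of_grad(x):
    @jax.custom_vjp
    def custom_sumsq(x):
        return jnp.sum(x ** 2)

    def _fwd(x):
        return custom_sumsq(x), x        # residual: x

    def _bwd(x, cotangent):
        return (cotangent * 2 * x,)      # d/dx sum(x**2) = 2x

    custom_sumsq.defvjp(_fwd, _bwd)

    # jax.grad needs a scalar output, so sum the first-order gradient
    # before differentiating it again.
    return jax.grad(lambda y: jnp.sum(jax.grad(custom_sumsq)(y)))(x)

print(custom_vjp_grad_of_grad(jnp.array([1.0, 2.0, 3.0])))   # [2. 2. 2.]
```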

Hints

jax custom-vjp higher-order
