
Custom VJP: Implicit Function Theorem

Why this matters

Many ML primitives are defined implicitly: a fixed-point iteration, an optimization layer, an optimal-transport map, or a deep equilibrium network. Differentiating through N iterations means backpropagating through an unrolled N-step computation graph, which is memory-intensive and can be numerically fragile.

The Implicit Function Theorem (IFT) gives a closed-form gradient using only the converged solution, bypassing the iteration entirely. A custom VJP is the clean way to express this in JAX.

For Newton iteration converging to y* = sqrt(x), the IFT gradient is:

d y* / d x = 0.5 / y*  =  0.5 / sqrt(x)

This matches d/dx sqrt(x) exactly, but is derived from the fixed-point condition y* = 0.5 * (y* + x / y*) rather than by differentiating the loop.
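
The step from the fixed-point condition to the gradient can be written out as follows. This is a sketch of the standard IFT argument; the residual function F is a name introduced here for illustration and does not appear in the original text.

```latex
\begin{align*}
F(y, x) &= y - \tfrac{1}{2}\left(y + \frac{x}{y}\right) = 0
  \quad \text{at } y = y^* \\
\frac{\partial F}{\partial y}
  &= 1 - \tfrac{1}{2}\left(1 - \frac{x}{y^2}\right)
   = \tfrac{1}{2} + \frac{x}{2y^2}
   = 1 \quad \text{when } y^2 = x \\
\frac{\partial F}{\partial x} &= -\frac{1}{2y} \\
\frac{dy^*}{dx}
  &= -\left(\frac{\partial F}{\partial y}\right)^{-1}
     \frac{\partial F}{\partial x}
   = \frac{1}{2y^*} = \frac{0.5}{\sqrt{x}}
\end{align*}
```

Only the converged value y* enters the result; the number of Newton steps used to reach it does not.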

IFT gradients power:

  • Deep Equilibrium Models (DEQ): infinite-depth networks with O(1) memory.
  • Differentiable convex solvers (cvxpylayers, OptNet).
  • Optimal transport (Sinkhorn iterations).
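
To show how the same pattern extends beyond sqrt, here is a minimal sketch of a DEQ-style layer z* = tanh(W z* + x) with an IFT backward pass. The names deq_layer and _iterate, the tanh map, the fixed 50 steps, and the dense jnp.linalg.solve are all assumptions chosen for illustration; this is not how any particular DEQ library implements it, and a dense solve is only reasonable for small layers.

```python
import jax
import jax.numpy as jnp

def _iterate(W, x, n_steps=50):
    # Plain fixed-point iteration z <- tanh(W z + x); assumes W is a contraction.
    z = jnp.zeros_like(x)
    for _ in range(n_steps):
        z = jnp.tanh(W @ z + x)
    return z

@jax.custom_vjp
def deq_layer(W, x):
    return _iterate(W, x)

def _deq_fwd(W, x):
    z = _iterate(W, x)
    return z, (W, z)                      # residuals: weights and converged z

def _deq_bwd(res, g):
    W, z = res
    d = 1.0 - z ** 2                      # tanh'(W z + x) at the fixed point
    A = jnp.eye(z.shape[0]) - d[:, None] * W   # I - D W, with D = diag(d)
    v = d * jnp.linalg.solve(A.T, g)      # D (I - D W)^{-T} g = cotangent of x
    return jnp.outer(v, z), v             # (cotangent of W, cotangent of x)

deq_layer.defvjp(_deq_fwd, _deq_bwd)

# Small illustrative inputs (hypothetical values, chosen so the map contracts).
W = 0.1 * jnp.array([[1.0, -2.0], [0.5, 1.5]])
x = jnp.array([0.3, -0.2])

grads_ift = jax.grad(lambda W, x: deq_layer(W, x).sum(), argnums=(0, 1))(W, x)
grads_unrolled = jax.grad(lambda W, x: _iterate(W, x).sum(), argnums=(0, 1))(W, x)
```

For this small contraction the two gradient pairs should agree closely, but the IFT version keeps only W and the converged z for the backward pass rather than every intermediate iterate.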

Worked mini-example

import jax
import jax.numpy as jnp

@jax.custom_vjp
def fixed_point_sqrt(x):
    y = jnp.ones_like(x)
    for _ in range(10):          # Newton: converges to sqrt(x)
        y = 0.5 * (y + x / y)
    return y

def _fwd(x):
    y = fixed_point_sqrt(x)
    return y, y                  # save converged y as residual

def _bwd(y, g):
    return (g * 0.5 / y,)        # IFT: d sqrt(x)/dx = 0.5/sqrt(x) = 0.5/y

fixed_point_sqrt.defvjp(_fwd, _bwd)

x = jnp.array([4.0, 9.0])
print(fixed_point_sqrt(x))       # [2., 3.]

g = jax.grad(lambda x: fixed_point_sqrt(x).sum())(x)
print(g)                         # ≈ [0.25, 0.1667]  =  0.5 / [2, 3]
# (matches d sqrt(x)/dx, but the backward pass does not unroll the 10 Newton steps)
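
One way to sanity-check the custom rule is to compare it against plain autodiff through the unrolled loop and against the derivative of jnp.sqrt. The sketch below assumes the Newton loop is factored into a helper, newton_sqrt, a name introduced here for illustration so the same loop can be differentiated both ways.

```python
import jax
import jax.numpy as jnp

def newton_sqrt(x, n_steps=10):
    # Undecorated Newton loop; jax.grad on this unrolls all n_steps.
    y = jnp.ones_like(x)
    for _ in range(n_steps):
        y = 0.5 * (y + x / y)
    return y

@jax.custom_vjp
def fixed_point_sqrt(x):
    return newton_sqrt(x)

def _fwd(x):
    y = newton_sqrt(x)
    return y, y                           # converged y is the only residual

def _bwd(y, g):
    return (g * 0.5 / y,)                 # IFT gradient

fixed_point_sqrt.defvjp(_fwd, _bwd)

x = jnp.array([4.0, 9.0])
g_ift = jax.grad(lambda x: fixed_point_sqrt(x).sum())(x)
g_unrolled = jax.grad(lambda x: newton_sqrt(x).sum())(x)
g_exact = jax.grad(lambda x: jnp.sqrt(x).sum())(x)

print(jnp.allclose(g_ift, g_unrolled, atol=1e-5))
print(jnp.allclose(g_ift, g_exact, atol=1e-5))
```

The agreement relies on the iteration having converged for these inputs; for inputs where 10 steps are not enough, the unrolled gradient and the IFT gradient can differ.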

Common pitfalls

  • Saving the right residual: _fwd must save the converged y, not x, as the residual. The _bwd rule is written in terms of y = sqrt(x), not x directly.
  • Trailing comma: _bwd must return a tuple with one entry per primal input, here (g * 0.5 / y,). Without the comma it is not a tuple and JAX will raise a type error.
  • Unrolling cost: without custom_vjp, JAX would differentiate through all 10 Newton iterations, creating a 10× deeper computation graph for no benefit.
  • Division by zero: the IFT gradient 0.5 / y is undefined at y = 0 (i.e. x = 0). The problem guarantees positive inputs.
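
The last pitfall can be made concrete with a small experiment. This sketch assumes the same fixed_point_sqrt definition as the worked example. Starting from y = 1 with x = 0, each Newton step halves y, so after 10 steps y is 2^-10 rather than 0. The custom backward rule therefore returns a finite but meaningless 0.5 / 2^-10 = 512, while the derivative of jnp.sqrt at 0 is inf.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def fixed_point_sqrt(x):
    y = jnp.ones_like(x)
    for _ in range(10):
        y = 0.5 * (y + x / y)
    return y

def _fwd(x):
    y = fixed_point_sqrt(x)
    return y, y

def _bwd(y, g):
    return (g * 0.5 / y,)

fixed_point_sqrt.defvjp(_fwd, _bwd)

x0 = jnp.array([0.0])                     # outside the problem's stated domain
y0 = fixed_point_sqrt(x0)                 # 2**-10: the loop has not reached 0
g0 = jax.grad(lambda x: fixed_point_sqrt(x).sum())(x0)   # 0.5 / 2**-10
g_ref = jax.grad(lambda x: jnp.sqrt(x).sum())(x0)        # reference derivative

print(y0, g0, g_ref)
```

No error is raised in this case, so input validation or a documented precondition is the safer guard.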

Problem

Implement fixed_point_sqrt(x) that computes sqrt(x) via 10 Newton iterations with a custom_vjp using the IFT gradient 0.5 / sqrt(x).

  • x: 1-D jax array of positive values.

Returns: 1-D array of the same shape as x, containing sqrt(x).

Hints

jax custom-vjp implicit-function
