Custom VJP: Implicit Function Theorem
Why this matters
Many ML primitives are defined implicitly: a fixed-point iteration, an
optimization layer, an optimal-transport map, or a deep equilibrium network.
Differentiating through N iterations unrolls an N-step computation
graph, which is memory-intensive and numerically fragile.
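To make the cost of unrolling concrete, the sketch below counts the primitive operations JAX traces for the gradient of a plain-loop Newton iteration. The names unrolled_sqrt and grad_graph_size are hypothetical helpers introduced for this illustration, and the exact counts depend on the JAX version, so only the trend matters.

```python
import jax
import jax.numpy as jnp

def unrolled_sqrt(x, num_steps):
    """Newton iteration for sqrt(x), written as a plain Python loop."""
    y = jnp.ones_like(x)
    for _ in range(num_steps):
        y = 0.5 * (y + x / y)
    return y

def grad_graph_size(num_steps):
    """Count the equations in the traced gradient computation (hypothetical helper)."""
    grad_fn = jax.grad(lambda x: unrolled_sqrt(x, num_steps).sum())
    closed_jaxpr = jax.make_jaxpr(grad_fn)(jnp.array([4.0, 9.0]))
    return len(closed_jaxpr.jaxpr.eqns)

# The traced gradient grows with the number of unrolled Newton steps.
for n in (5, 10, 20):
    print(n, grad_graph_size(n))
```

The count rises with the step count because reverse-mode autodiff has to replay every iteration of the loop.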
The Implicit Function Theorem (IFT) gives a closed-form gradient using only the converged solution, bypassing the iteration entirely. A custom VJP is the clean way to express this in JAX.
For Newton iteration converging to y* = sqrt(x), the IFT gradient is:
d y* / d x = 0.5 / y* = 0.5 / sqrt(x)
This matches d/dx sqrt(x) exactly, but is derived from the fixed-point
condition y* = 0.5 * (y* + x / y*) rather than by differentiating the loop.
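The derivation can be checked numerically. Writing the fixed-point condition as a residual F(y, x) = y - 0.5 * (y + x / y) = 0, the IFT gives d y* / d x = -(dF/dx) / (dF/dy), evaluated at the solution. The sketch below is one way to verify this for a scalar input; residual and ift_grad are hypothetical names, not part of the problem's API.

```python
import jax
import jax.numpy as jnp

def residual(y, x):
    """F(y, x) = 0 exactly when y is a fixed point of the Newton update."""
    return y - 0.5 * (y + x / y)

def ift_grad(y_star, x):
    """IFT: dy*/dx = -(dF/dx) / (dF/dy), evaluated at the converged y*."""
    dF_dy = jax.grad(residual, argnums=0)(y_star, x)
    dF_dx = jax.grad(residual, argnums=1)(y_star, x)
    return -dF_dx / dF_dy

x = 9.0
y_star = jnp.sqrt(x)  # stands in for the converged Newton iterate

# Both values should agree: the IFT result and the closed form 0.5 / y*.
print(ift_grad(y_star, x), 0.5 / y_star)
```

At the solution, dF/dy equals 1 and dF/dx equals -0.5 / y*, which is where the 0.5 / y* formula comes from.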
IFT gradients power:
- Deep Equilibrium Models (DEQ): infinite-depth networks with O(1) memory.
- Differentiable convex solvers (cvxpylayers, OptNet).
- Optimal transport (Sinkhorn iterations).
Worked mini-example
import jax
import jax.numpy as jnp

@jax.custom_vjp
def fixed_point_sqrt(x):
    y = jnp.ones_like(x)
    for _ in range(10):  # Newton: converges to sqrt(x)
        y = 0.5 * (y + x / y)
    return y

def _fwd(x):
    y = fixed_point_sqrt(x)
    return y, y  # save converged y as residual

def _bwd(y, g):
    return (g * 0.5 / y,)  # IFT: d sqrt(x)/dx = 0.5/sqrt(x) = 0.5/y

fixed_point_sqrt.defvjp(_fwd, _bwd)

x = jnp.array([4.0, 9.0])
print(fixed_point_sqrt(x))  # [2., 3.]
g = jax.grad(lambda x: fixed_point_sqrt(x).sum())(x)
# -> [0.25, 0.1667] = 0.5 / [2, 3]
# (matches d sqrt(x)/dx, but computed without unrolling 10 Newton steps)
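As a sanity check, the custom rule can be compared against JAX's built-in derivative of jnp.sqrt. The sketch below restates the definition so that it runs on its own; loss_ift, loss_ref and the tolerance are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def fixed_point_sqrt(x):
    y = jnp.ones_like(x)
    for _ in range(10):  # Newton iteration toward sqrt(x)
        y = 0.5 * (y + x / y)
    return y

def _fwd(x):
    y = fixed_point_sqrt(x)
    return y, y  # the converged y is the only residual needed

def _bwd(y, g):
    return (g * 0.5 / y,)  # IFT gradient, returned as a 1-tuple

fixed_point_sqrt.defvjp(_fwd, _bwd)

def loss_ift(x):
    return fixed_point_sqrt(x).sum()

def loss_ref(x):
    return jnp.sqrt(x).sum()

x = jnp.array([2.0, 4.0, 9.0])
grad_ift = jax.jit(jax.grad(loss_ift))(x)  # the custom rule also works under jit
grad_ref = jax.grad(loss_ref)(x)
grads_match = bool(jnp.allclose(grad_ift, grad_ref, rtol=1e-5))
print(grads_match)
```

The comparison only holds where 10 Newton steps from y = 1 have converged, which is the case for the moderate positive inputs used here.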
Common pitfalls
- Saving the right residual: _fwd must return the converged y, not x. The _bwd rule uses y = sqrt(x), not x directly.
- Trailing comma: _bwd must return (cotangent * 0.5 / y,). Without the comma, it is not a tuple and JAX will raise a type error.
- Without custom_vjp, JAX would differentiate through all 10 Newton iterations, creating a 10× deeper computation graph for no benefit.
- Division by zero: the IFT gradient 0.5 / y is undefined at y = 0 (i.e. x = 0). The problem guarantees positive inputs.
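The division-by-zero pitfall has a subtle side. With a fixed budget of 10 steps starting from y = 1, the iterate at x = 0 only halves on each step and never reaches 0, so the formula 0.5 / y returns a large finite number instead of failing loudly. The sketch below illustrates this with a plain loop; newton_sqrt is a hypothetical helper that mirrors the forward pass of the worked example.

```python
import jax.numpy as jnp

def newton_sqrt(x, num_steps=10):
    """Plain Newton loop, same update as the worked example's forward pass."""
    y = jnp.ones_like(x)
    for _ in range(num_steps):
        y = 0.5 * (y + x / y)
    return y

x = jnp.array([0.0, 4.0])
y = newton_sqrt(x)
ift_grad = 0.5 / y

print(y)         # first entry is 2**-10 (not 0), second is 2
print(ift_grad)  # first entry is 512 (finite but meaningless), second is 0.25
```

This is why the restriction to positive inputs matters: at x = 0 the true derivative of sqrt(x) is unbounded, and the unconverged iterate silently hides that.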
Problem
Implement fixed_point_sqrt(x) that computes sqrt(x) via 10 Newton
iterations with a custom_vjp using the IFT gradient 0.5 / sqrt(x).
- x: 1-D JAX array of positive values.
- Returns: 1-D array of the same shape, sqrt(x).