
Newton's Method via lax.while_loop

Why this matters

lax.while_loop is JAX’s primitive for loops whose termination condition depends on a traced (dynamic) value. Unlike a Python while loop, its iteration count is NOT known at trace time; it is determined at runtime by the data. That makes it the right tool for iterative algorithms such as Newton’s method, binary search, or gradient descent run until convergence.
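A minimal sketch of a data-dependent trip count (the helper name halvings_until_below_one is hypothetical, not part of the problem): it counts how many times x must be halved before it drops below 1. The same jitted function handles inputs that need different numbers of iterations, because the loop's stopping point is decided at runtime rather than baked in at trace time.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def halvings_until_below_one(x):
    # Carry: (current value, number of halvings performed so far).
    def cond(state):
        value, _ = state
        return value >= 1.0            # traced boolean scalar

    def body(state):
        value, count = state
        return value / 2.0, count + 1  # same structure and dtypes as the input

    x = jnp.asarray(x, jnp.float32)
    _, count = lax.while_loop(cond, body, (x, jnp.int32(0)))
    return count

print(halvings_until_below_one(8.0))    # 4
print(halvings_until_below_one(100.0))  # 7
```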

Newton’s method for sqrt(x) is a textbook example: start at y = 1, repeat y ← (y + x/y) / 2 until |y² − x| < tol. The iteration count varies with x and tol, so a Python while loop conditioned on the traced error |y² − x| fails under jit.
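For reference, the same iteration in plain Python (no JAX) runs fine eagerly, because the loop condition is an ordinary bool. This sketch (newton_sqrt_reference is a hypothetical name) also returns the step count, which makes it easy to see that the number of iterations depends on x.

```python
def newton_sqrt_reference(x, tol):
    """Plain-Python Newton iteration for sqrt(x); returns (y, steps)."""
    y, steps = 1.0, 0
    while abs(y * y - x) >= tol:   # a real Python bool, so `while` works here
        y = (y + x / y) / 2.0      # one Newton step
        steps += 1
    return y, steps

for x in (4.0, 2.0, 100.0):
    y, steps = newton_sqrt_reference(x, 1e-6)
    print(x, y, steps)
```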

Worked mini-example

from jax import lax
import jax.numpy as jnp

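def double_until_big(x):
    cond = lambda s: s < 100.0   # keep looping while the state is below 100
    body = lambda s: s * 2.0     # one step: double the state
    return lax.while_loop(cond, body, x)

double_until_big(1.0)   # → 128.0  (1, 2, 4, ..., 64 all pass the test; 128 does not)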

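lax.while_loop(cond_fun, body_fun, init_val) applies body_fun repeatedly while cond_fun(state) is True and returns the final state. Both functions take the current state as their only argument: cond_fun returns a boolean scalar, and body_fun returns a new state with the same structure, shapes, and dtypes as its input.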
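To make the contract concrete, here is a small sketch with a tuple state: Euclid's algorithm for the greatest common divisor (gcd is our illustrative helper, not part of the problem). cond_fun maps the state to a boolean scalar, while body_fun maps the state to a new state of identical structure and dtypes.

```python
import jax.numpy as jnp
from jax import lax

def gcd(a, b):
    # State: a tuple (a, b) of int32 scalars; both functions receive it.
    def cond_fun(state):
        _, b = state
        return b != 0            # boolean scalar, not a new state

    def body_fun(state):
        a, b = state
        return b, a % b          # same tuple structure, shapes, and dtypes

    init_val = (jnp.int32(a), jnp.int32(b))
    a_final, _ = lax.while_loop(cond_fun, body_fun, init_val)
    return a_final

print(gcd(48, 18))   # 6
```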

Common pitfalls

  • Python while under jit. while jnp.abs(y*y - x) >= tol: … raises ConcretizationTypeError, because under tracing the comparison produces a tracer, not a Python bool.
  • State shape mismatch. cond_fun and body_fun receive the same state, and body_fun must return a state with the same structure, shapes, and dtypes it received.
  • Carry as a tuple. The state can be any pytree; carrying both the iteration count and the current value as (i, y) is idiomatic and makes debugging easier.
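The first pitfall can be reproduced directly. In this sketch (python_while_sqrt is a hypothetical name), the Python-loop version runs when called eagerly, where the comparison is concrete, but fails once wrapped in jax.jit. We catch jax.errors.ConcretizationTypeError; recent JAX versions raise its subclass TracerBoolConversionError here, which the same except clause also catches.

```python
import jax
import jax.numpy as jnp

def python_while_sqrt(x, tol):
    y = 1.0
    # Under jit this comparison is a tracer, but `while` needs a Python bool.
    while jnp.abs(y * y - x) >= tol:
        y = (y + x / y) / 2.0
    return y

print(python_while_sqrt(2.0, 1e-6))   # eager call: concrete values, so it runs

jit_failed = False
try:
    jax.jit(python_while_sqrt)(2.0, 1e-6)
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)
```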

Problem

Implement while_sqrt(x, tol) using lax.while_loop. Iterate Newton’s method y ← (y + x/y) / 2, starting from y = 1.0, until |y² − x| < tol.

  • x: positive scalar.
  • tol: positive scalar (convergence threshold).
  • Returns: scalar — converged y ≈ sqrt(x).
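One possible sketch of while_sqrt, assuming float32 arithmetic and an (i, y) carry. MAX_ITERS is our own addition, not part of the problem statement: in float32, a tol smaller than the achievable rounding error (for example 1e-12 with x = 2) would never satisfy |y² − x| < tol, and the cap keeps the loop from running forever.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_ITERS = 100  # assumption: safety cap, not required by the problem

@jax.jit
def while_sqrt(x, tol):
    # Fix the dtype up front so the carry's dtype never changes between steps.
    x = jnp.asarray(x, jnp.float32)
    tol = jnp.asarray(tol, jnp.float32)

    def cond_fun(state):
        i, y = state
        not_converged = jnp.abs(y * y - x) >= tol
        return jnp.logical_and(not_converged, i < MAX_ITERS)

    def body_fun(state):
        i, y = state
        return i + 1, (y + x / y) / 2.0   # one Newton step, count it

    init_val = (jnp.int32(0), jnp.float32(1.0))   # start from y = 1.0
    _, y = lax.while_loop(cond_fun, body_fun, init_val)
    return y

print(while_sqrt(2.0, 1e-5))
print(while_sqrt(9.0, 1e-4))
```

Returning i alongside y is a small change if you want to inspect how many steps a given x and tol needed.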

Hints

jax while-loop fixed-point
