Newton's Method via lax.while_loop
Why this matters
lax.while_loop is JAX’s primitive for loops whose termination condition
depends on a traced (dynamic) value. Unlike Python while, the number of
iterations is NOT known at trace time — it is determined at runtime by the
data. That makes it the right tool for iterative algorithms like Newton’s
method, binary search, or gradient descent run until convergence.
Newton’s method for sqrt(x) is a textbook example: start at y = 1,
repeat y ← (y + x/y) / 2 until |y² − x| < tol. The iteration count
varies with x and tol, so a Python while loop that tests the traced y fails under jit.
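To see why the iteration count is data-dependent, here is the same Newton iteration written as an ordinary Python loop on Python floats, run eagerly with no JAX involved. The name newton_sqrt_python and the step counter are hypothetical, added only for illustration.

```python
def newton_sqrt_python(x, tol):
    """Plain-Python Newton iteration for sqrt(x), run eagerly (no jit).

    Returns the estimate and the number of updates performed, to show
    that the iteration count depends on the input data.
    """
    y = 1.0
    steps = 0
    # abs(...) >= tol is a real Python bool here because y and x are floats.
    while abs(y * y - x) >= tol:
        y = (y + x / y) / 2.0  # Newton update for f(y) = y**2 - x
        steps += 1
    return y, steps


for x, tol in [(2.0, 1e-3), (2.0, 1e-6), (1.0, 1e-6)]:
    y, steps = newton_sqrt_python(x, tol)
    print(x, tol, round(y, 6), steps)
```

For x = 2 this needs 3 updates at tol = 1e-3 and 4 at tol = 1e-6, while x = 1 needs none: a count that a traced program cannot know in advance.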
Worked mini-example
from jax import lax
import jax.numpy as jnp
def double_until_big(x):
    cond = lambda s: s < 100.0   # keep looping while the state is below 100
    body = lambda s: s * 2.0     # double the state on each iteration
    return lax.while_loop(cond, body, x)
double_until_big(1.0)  # → 128.0
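lax.while_loop(cond_fun, body_fun, init_val) repeatedly applies body_fun to
the state while cond_fun(state) is True. Both functions take the current state;
cond_fun returns a scalar boolean, and body_fun must return a new state with the
same structure, shapes, and dtypes as init_val.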
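The JAX documentation explains lax.while_loop's semantics with a Python equivalent along these lines. Here while_loop_reference is a hypothetical name for that model, not part of JAX; the real primitive traces cond_fun and body_fun once and runs the loop as compiled code, but on plain values it computes the same result.

```python
from jax import lax


def while_loop_reference(cond_fun, body_fun, init_val):
    # Pure-Python model of what lax.while_loop computes (no tracing, no jit).
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


cond = lambda s: s < 100.0  # returns a scalar boolean, not a state
body = lambda s: s * 2.0    # returns a new state with the same shape/dtype

# Both evaluate to 128: 1 -> 2 -> 4 -> ... -> 64 -> 128.
print(while_loop_reference(cond, body, 1.0))
print(lax.while_loop(cond, body, 1.0))
```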
Common pitfalls
- Python while under jit. while jnp.abs(y*y - x) >= tol: … raises
  ConcretizationTypeError (in recent JAX versions, its subclass
  TracerBoolConversionError), because the comparison is a tracer, not a Python bool.
- State shape mismatch. cond_fun and body_fun must accept the same state, and
  body_fun must return a state with the same structure, shapes, and dtypes;
  otherwise lax.while_loop raises a TypeError.
- Carry can be a tuple. Tracking both the iteration count and the current value
  as (i, y) is idiomatic and makes debugging easier.
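A minimal sketch of the first and third pitfalls, assuming a recent JAX install. The function names newton_python_while and count_doublings are hypothetical. The first fails while jit is tracing it; the second carries an (i, s) tuple through lax.while_loop. Catching jax.errors.ConcretizationTypeError also covers its TracerBoolConversionError subclass.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def newton_python_while(x, tol):
    y = jnp.float32(1.0)
    # Under jit, x and tol are tracers, so this condition is a tracer too;
    # Python's `while` needs a concrete bool and raises during tracing.
    while jnp.abs(y * y - x) >= tol:
        y = (y + x / y) / 2.0
    return y


raised = False
try:
    newton_python_while(2.0, 1e-6)
except jax.errors.ConcretizationTypeError as err:
    raised = True
    print("fails under jit:", type(err).__name__)


@jax.jit
def count_doublings(x):
    # The carry is a tuple (i, s): an iteration counter plus the current value.
    def cond(carry):
        i, s = carry
        return s < 100.0

    def body(carry):
        i, s = carry
        return i + 1, s * 2.0  # same structure and dtypes as the input carry

    return lax.while_loop(cond, body, (0, x))


steps, value = count_doublings(1.0)
print(steps, value)  # 7 doublings take 1.0 to 128.0
```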
Problem
Implement while_sqrt(x, tol) using lax.while_loop. Iterate Newton’s
method y ← (y + x/y) / 2, starting from y = 1.0, until
|y² − x| < tol.
- x: positive scalar.
- tol: positive scalar (convergence threshold).
- Returns: scalar, the converged y ≈ sqrt(x).
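One possible sketch of while_sqrt, not the site's reference solution. It assumes float32 inputs and uses the (i, y) carry idiom from the pitfalls list, even though only y is returned.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def while_sqrt(x, tol):
    # Assumption: work in float32 (JAX's default floating-point dtype).
    x = jnp.asarray(x, dtype=jnp.float32)
    tol = jnp.asarray(tol, dtype=jnp.float32)

    def cond_fun(state):
        _, y = state
        return jnp.abs(y * y - x) >= tol  # keep iterating until |y^2 - x| < tol

    def body_fun(state):
        i, y = state
        # One Newton step; the counter i is carried along for debugging.
        return i + 1, (y + x / y) / 2.0

    init_state = (jnp.int32(0), jnp.float32(1.0))
    _, y = lax.while_loop(cond_fun, body_fun, init_state)
    return y


print(while_sqrt(2.0, 1e-6))
print(while_sqrt(9.0, 1e-4))
```

lax.while_loop has no built-in iteration cap, so a tol smaller than float32 can resolve for a given x (for example 1e-12 with x = 2) may never be met, and the loop would not terminate. Adding a hypothetical max_iters bound to cond_fun, using the carried i, would guard against that.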