lax.while_loop vs Python while
Why this matters
A very common mistake when moving JAX code under jit is leaving a
Python while loop whose exit condition depends on a traced value. JAX traces
functions symbolically: a traced scalar has no concrete Python bool, so a loop
header such as while traced_value >= threshold: raises ConcretizationTypeError.
The standard fix is lax.while_loop. It traces the body and condition
once, emits a single XLA while op, and lets the iteration count be
fully data-dependent at runtime, which Python control flow cannot do
inside a JIT-compiled function.
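The JAX API documentation describes the semantics of lax.while_loop with a plain Python equivalent along the lines of the sketch below. It runs that reference eagerly and then uses jax.make_jaxpr to check that the traced version becomes a single while primitive. The name while_loop_reference is illustrative, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax


def while_loop_reference(cond_fun, body_fun, init_val):
    # Pure-Python model of lax.while_loop (illustrative name, not a JAX API):
    # test the condition first, apply the body while it holds, return the carry.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


cond = lambda s: s > 1.0
body = lambda s: s * 0.5

# Eager run with a concrete Python float: 4.0 -> 2.0 -> 1.0
ref = while_loop_reference(cond, body, 4.0)

# The real primitive gives the same answer on a JAX array.
traced_result = lax.while_loop(cond, body, jnp.float32(4.0))

# Tracing: cond and body are traced once and appear inside one `while` equation.
jaxpr = jax.make_jaxpr(lambda x: lax.while_loop(cond, body, x))(4.0)
primitive_names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
```

The reference only works with concrete values; the jaxpr shows why lax.while_loop does not need them: the loop is staged out as data rather than unrolled by Python.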
Worked mini-example
import jax.numpy as jnp
from jax import lax, jit

@jit
def bad(x):
    while x > 1.0:  # raises ConcretizationTypeError when bad() is traced
        x = x * 0.5
    return x

@jit
def good(x):
    cond = lambda s: s > 1.0
    body = lambda s: s * 0.5
    return lax.while_loop(cond, body, x)

good(4.0)  # → 1.0 (4.0 → 2.0 → 1.0; the loop stops because 1.0 > 1.0 is False)
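To see both halves of the example in action, the sketch below calls bad and catches the error, then calls good on inputs that need different numbers of iterations. trace_log is a hypothetical helper added only to count how often good is traced; a Python side effect inside a jitted function runs at trace time, not on every call.

```python
import jax
import jax.numpy as jnp
from jax import lax, jit


@jit
def bad(x):
    while x > 1.0:  # needs a concrete bool, but x is a tracer here
        x = x * 0.5
    return x


trace_log = []  # hypothetical helper: records each trace of `good`


@jit
def good(x):
    trace_log.append("traced")  # Python side effect: runs only while tracing
    return lax.while_loop(lambda s: s > 1.0, lambda s: s * 0.5, x)


try:
    bad(4.0)
    bad_raised = False
except jax.errors.ConcretizationTypeError:
    bad_raised = True

# One compiled function; the trip count (2, 4, 0) is decided at runtime.
results = [float(good(jnp.float32(v))) for v in (4.0, 10.0, 0.5)]
```

All three calls share the same input shape and dtype, so good is traced once and the differing iteration counts are handled entirely inside the compiled while op.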
Common pitfalls
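- Python while under jit. Even if jit is not applied explicitly, functions called inside vmap, another jit, or a lax control-flow body (such as lax.scan or lax.cond) hit the same error. grad on its own is the exception: the JAX documentation notes that Python control flow works under grad as long as no jit is involved.
- Returning count as int. The counter in the carry is an int32 tracer; the grader may expect float32. Cast explicitly: jnp.float32(count).
- Incorrect convergence count. The body runs only while the condition is True, so the count equals the number of body executions before the condition first becomes False.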
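The first pitfall can be checked directly. The sketch below writes the same halving loop twice (halve_python and halve_lax are illustrative names): plain grad accepts the Python version because it sees concrete values, vmap rejects it, and the lax.while_loop version batches cleanly.

```python
import jax
import jax.numpy as jnp
from jax import lax


def halve_python(x):
    # Python while: the condition must produce a concrete bool.
    while x > 1.0:
        x = x * 0.5
    return x


def halve_lax(x):
    # The same computation as a single lax.while_loop.
    return lax.while_loop(lambda s: s > 1.0, lambda s: s * 0.5, x)


# Plain grad (no jit): at x = 4.0 the loop halves twice, so f(x) = 0.25 * x.
grad_at_4 = jax.grad(halve_python)(4.0)

xs = jnp.array([4.0, 10.0, 0.5])

# vmap traces abstractly, so the Python version fails...
try:
    jax.vmap(halve_python)(xs)
    vmap_python_failed = False
except jax.errors.ConcretizationTypeError:
    vmap_python_failed = True

# ...while the lax version keeps looping until every lane's condition is
# False; lanes that finish early keep their final value.
batched = jax.vmap(halve_lax)(xs)
```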
Problem
Implement safe_iterate_until(x, threshold) using lax.while_loop.
Repeatedly multiply x by 0.9 until it falls below threshold.
Return the number of multiplications performed, as a float32 scalar.
- x: positive scalar (starts above threshold).
- threshold: positive scalar (< x).
- Returns: float32 scalar, the number of iterations until x < threshold.
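One possible sketch of the problem, not the site's reference solution: the carry is a (value, counter) pair, the condition keeps looping while the value is still at or above the threshold, and the int32 counter is cast to float32 only at the end.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def safe_iterate_until(x, threshold):
    # Carry: (current value, number of multiplications so far).
    def cond(carry):
        val, _ = carry
        return val >= threshold  # stop as soon as val < threshold

    def body(carry):
        val, count = carry
        return val * 0.9, count + 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.int32(0))
    _, count = lax.while_loop(cond, body, init)
    return count.astype(jnp.float32)  # int32 counter -> float32 result
```

For example, safe_iterate_until(1.0, 0.5) multiplies seven times (0.9**6 ≈ 0.531 is still above 0.5, 0.9**7 ≈ 0.478 is not) and returns 7.0. Keeping the counter as int32 inside the loop avoids accumulating rounding error in the count.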