medium primitives

lax.while_loop vs Python while

Why this matters

The most common mistake when moving JAX code under jit is leaving a Python while loop whose exit condition depends on a traced value. Under jit, JAX traces the function with abstract values: a traced scalar has no concrete Python bool, so a loop header such as while traced_value >= threshold: raises ConcretizationTypeError (in recent JAX versions, its subclass TracerBoolConversionError).
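
To see the failure concretely, here is a minimal sketch (the function name halve_until_small is illustrative). The same Python loop runs eagerly on a concrete array but fails once jit traces it. The sketch catches jax.errors.ConcretizationTypeError, the base class, so it should behave the same whether your JAX version raises that error directly or the newer TracerBoolConversionError subclass.

```python
import jax
import jax.numpy as jnp

def halve_until_small(x):
    # Plain Python loop: fine on concrete values, fails on tracers.
    while x > 1.0:
        x = x * 0.5
    return x

# Without jit, x is a concrete array, so `x > 1.0` converts to a Python bool.
eager_result = halve_until_small(jnp.float32(4.0))

# Under jit, x is a tracer; bool(tracer) raises ConcretizationTypeError
# (recent JAX versions raise the subclass TracerBoolConversionError).
try:
    jax.jit(halve_until_small)(4.0)
    traced_error = None
except jax.errors.ConcretizationTypeError as e:
    traced_error = type(e).__name__

print(float(eager_result))   # 1.0
print(traced_error)
```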

The standard fix is lax.while_loop. It traces the body and condition once, emits a single XLA while op, and lets the iteration count depend on runtime data, something a Python loop cannot do inside a JIT-compiled function.
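
A rough way to check these claims, assuming a recent JAX install: the sketch below compares lax.while_loop with the pure-Python reference semantics described in its documentation, inspects the jaxpr for a single while primitive, and reuses one jitted function for inputs that need different iteration counts. The helper names (while_loop_reference, halve) are hypothetical.

```python
import jax
from jax import lax

def while_loop_reference(cond_fun, body_fun, init_val):
    # Pure-Python semantics of lax.while_loop; only usable on concrete
    # values, never under jit.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

def cond(s):
    return s > 1.0

def body(s):
    return s * 0.5

def halve(x):
    return lax.while_loop(cond, body, x)

# The traced program contains one `while` primitive; cond and body are
# each traced once into sub-jaxprs.
jaxpr_text = str(jax.make_jaxpr(halve)(4.0))

# One compiled function handles inputs needing different iteration counts.
halve_jit = jax.jit(halve)
inputs = (4.0, 40.0, 400.0)
results = [float(halve_jit(v)) for v in inputs]
reference = [while_loop_reference(cond, body, v) for v in inputs]
print(results)     # [1.0, 0.625, 0.78125]
print(reference)   # [1.0, 0.625, 0.78125]
```

The three inputs need 2, 6, and 9 halvings respectively, yet halve_jit is traced and compiled only once, because all three calls share the same input shape and dtype.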

Worked mini-example

import jax.numpy as jnp
from jax import lax, jit

@jit
def bad(x):
    while x > 1.0:   # calling bad(4.0) raises ConcretizationTypeError at trace time
        x = x * 0.5
    return x

@jit
def good(x):
    cond = lambda s: s > 1.0
    body = lambda s: s * 0.5
    return lax.while_loop(cond, body, x)

good(4.0)   # → 1.0  (4.0 → 2.0 → 1.0; stops because 1.0 > 1.0 is False)

Common pitfalls

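  • Python while under jit. The same error appears without an explicit jit when the function runs under vmap or is called from another jit-compiled function. grad on its own still allows Python control flow on concrete values, but note that lax.while_loop supports only forward-mode differentiation, not reverse-mode.
  • Returning the count as an int. A counter initialized to 0 in the carry is an int32 tracer, but the problem asks for a float32 result. Cast explicitly before returning: jnp.float32(count).
  • Off-by-one convergence count. The body runs only while the condition returns True, so the count is the number of body executions before the condition first returns False. When the stopping rule reads "until x < threshold", the loop condition is its negation, x >= threshold.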
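
The last two pitfalls both come down to how the counter is carried. Here is a minimal sketch that uses the halving loop from the mini-example rather than the graded problem. The carry is a tuple (value, count), the body returns the same structure with the same dtypes, and the count is cast to float32 at the end. count_halvings is an illustrative name, not part of the problem's API.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def count_halvings(x):
    # Carry is a tuple (value, count); cond and body both receive it, and
    # body must return a carry with the same structure and dtypes.
    def cond(carry):
        s, _ = carry
        return s > 1.0

    def body(carry):
        s, count = carry
        return s * 0.5, count + 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.int32(0))
    _, count = lax.while_loop(cond, body, init)
    # count is an int32 tracer here; cast to float32 before returning.
    return jnp.float32(count)

print(float(count_halvings(4.0)))   # 2.0  (4.0 -> 2.0 -> 1.0; body ran twice)
print(float(count_halvings(0.5)))   # 0.0  (cond is False at entry; body never runs)
```

Initializing the counter with an explicit dtype (jnp.int32(0)) makes the carry types obvious. lax.while_loop requires the body to return a carry whose structure, shapes, and dtypes match the initial value.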

Problem

Implement safe_iterate_until(x, threshold) using lax.while_loop. Repeatedly multiply x by 0.9 until it falls below threshold. Return the number of multiplications performed, as a float32 scalar.

  • x: positive scalar (starts above threshold).
  • threshold: positive scalar (< x).
  • Returns: scalar float32, the number of iterations until x < threshold.

Hints

jax while-loop tracing
