Bounded Search via lax.while_loop
Why this matters
Many real algorithms scan an array and stop as soon as a condition is met: finding the first occurrence of a value, detecting the first NaN, or locating the end of a variable-length sequence. Under JIT these loops cannot use a Python for loop with break, because the stopping test depends on traced values. They need lax.while_loop with a carry that is just the current index.
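The sketch below shows the naive version and why the early exit needs lax.while_loop. find_negative_python is a hypothetical name used only for illustration. It works eagerly, where each comparison is a concrete value. Under jax.jit, however, the if statement needs a Python bool from an abstract tracer, so tracing fails. The sketch catches jax.errors.ConcretizationTypeError, assuming (as in recent JAX releases) that the more specific TracerBoolConversionError is a subclass of it.

```python
import jax
import jax.numpy as jnp

def find_negative_python(x):
    # Hypothetical helper: an ordinary Python loop with an early exit.
    n = x.shape[0]
    found = n
    for i in range(n):
        if x[i] < 0.0:  # Python `if` needs a concrete bool here
            found = i
            break
    return found

x = jnp.array([1.0, 2.0, -3.0, 4.0])
print(find_negative_python(x))  # 2 (eager: each comparison is concrete)

try:
    jax.jit(find_negative_python)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    # Under jit, x[i] < 0.0 is an abstract tracer with no Python truth value.
    jit_failed = True
    print("jit failed:", type(err).__name__)
```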
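This problem adds a real-world wrinkle: the loop must be bounded so that it never walks past the end of the array. JAX does not raise on an out-of-bounds read; by default it clamps the index to the last valid position. With an uncapped max_iters larger than len(x), the loop therefore keeps re-reading the last element and can return max_iters instead of len(x).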
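Here is a minimal sketch of what goes wrong without the cap. It assumes JAX's default indexing behavior, which clamps an out-of-bounds read to the last valid index rather than raising. find_zero_uncapped is a hypothetical, deliberately buggy helper. Reads at len(x) and beyond keep returning the last (non-zero) element, so the loop stops only when it reaches max_iters.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])  # contains no zero

# Out-of-bounds read: clamped to the last element instead of raising.
print(x[jnp.int32(10)])  # 3.0

def find_zero_uncapped(x, max_iters):
    # Hypothetical buggy helper: bounded by max_iters only, not by len(x).
    cond = lambda idx: (idx < max_iters) & (x[idx] != 0.0)
    body = lambda idx: idx + 1
    return lax.while_loop(cond, body, jnp.int32(0))

print(find_zero_uncapped(x, 10))  # 10, although len(x) is 3
```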
Worked mini-example
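```python
from jax import lax
import jax.numpy as jnp

def find_negative(x):
    n = x.shape[0]  # static length, known at trace time
    # Keep stepping while in bounds AND the current element is non-negative.
    cond = lambda idx: (idx < n) & (x[idx] >= 0.0)
    body = lambda idx: idx + 1  # the carry is just the current index
    # Returns the first negative index, or n if there is none.
    return lax.while_loop(cond, body, jnp.int32(0))

find_negative(jnp.array([1.0, 2.0, -3.0, 4.0]))
# → 2
```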
Key: use & (elementwise AND), not Python and, when combining traced booleans. Python and must call bool() on its left operand to decide whether to short-circuit. An abstract tracer has no concrete truth value, so that call raises an error.
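The difference can be checked directly. In this sketch, cond_and and cond_amp are hypothetical names. The and version fails while lax.while_loop traces the condition, and the & version runs. Note that & evaluates both operands, so x[idx] is read even on the step where idx == n. That read is harmless because JAX clamps out-of-bounds reads and idx < n is already False.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([3.0, 0.0, 5.0])
n = x.shape[0]
body = lambda idx: idx + 1

def cond_and(idx):
    # Hypothetical: Python `and` calls bool() on a tracer and fails.
    return (idx < n) and (x[idx] != 0.0)

def cond_amp(idx):
    # Hypothetical: `&` builds a traced elementwise AND instead.
    return (idx < n) & (x[idx] != 0.0)

try:
    lax.while_loop(cond_and, body, jnp.int32(0))
    and_failed = False
except jax.errors.ConcretizationTypeError:
    and_failed = True

print(and_failed)                                    # True
print(lax.while_loop(cond_amp, body, jnp.int32(0)))  # 1
```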
Common pitfalls
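- Python and on tracers. (idx < n) and (x[idx] != 0.0) raises TracerBoolConversionError. Always use &.
- Not capping max_iters. If max_iters > len(x), the loop tries to read x[len(x)], which is out of bounds. JAX clamps that read to the last element instead of raising, so the loop can run on to max_iters. Cap it: min(max_iters, n), or jnp.minimum(max_iters, n) when max_iters is a traced value.
- Wrong carry dtype. Use jnp.int32(0) as the init value. A plain Python 0 is weakly typed and takes its dtype from the jax_enable_x64 setting (int64 when x64 is enabled), so the returned index may not be the int32 the problem expects.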
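Python's min works for the cap only when max_iters is a static Python int. Under jax.jit, a traced max_iters makes min call bool() on a tracer. The minimal sketch below assumes max_iters may arrive as a traced argument and compares min with jnp.minimum. bound_python_min and bound_jnp_minimum are hypothetical helper names. The cast to int32 keeps the bound consistent with an int32 index carry.

```python
import jax
import jax.numpy as jnp

def bound_python_min(x, max_iters):
    # Hypothetical: Python min() needs a concrete comparison result.
    return min(max_iters, x.shape[0])

def bound_jnp_minimum(x, max_iters):
    # Hypothetical: jnp.minimum stays traced; cast to int32 like the carry.
    return jnp.minimum(jnp.asarray(max_iters, dtype=jnp.int32), x.shape[0])

x = jnp.zeros(4)
print(bound_python_min(x, 10))            # 4 (static Python ints, no jit)
print(jax.jit(bound_jnp_minimum)(x, 10))  # 4

try:
    jax.jit(bound_python_min)(x, 10)
    min_failed = False
except jax.errors.ConcretizationTypeError:
    min_failed = True
print(min_failed)  # True
```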
Problem
Implement while_find_first_zero(x, max_iters) using lax.while_loop.
Walk x one element at a time, stopping at the first zero. If no zero is
found within min(max_iters, len(x)) steps, return that bound.
- x: 1-D JAX array.
- max_iters: scalar (cast to int32 internally).
- Returns: scalar int32, the index of the first zero, or min(len(x), max_iters) if none is found within the limit.
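Below is one possible sketch of while_find_first_zero that follows the pitfalls listed above. It is not the official solution. It assumes x is non-empty and max_iters is non-negative. The bound is capped with jnp.minimum so that max_iters can be either a Python int or a traced value, and the loop starts from an int32 carry so the result is int32.

```python
import jax
import jax.numpy as jnp
from jax import lax

def while_find_first_zero(x, max_iters):
    """Sketch: index of the first zero in x, or min(len(x), max_iters)."""
    n = x.shape[0]  # static Python int, even under jit
    # Cap the bound; jnp.minimum also works when max_iters is traced.
    limit = jnp.minimum(jnp.asarray(max_iters, dtype=jnp.int32), n)

    def cond(idx):
        # `&` evaluates both sides. The only out-of-bounds read (idx == n)
        # happens when limit == n, where idx < limit is already False.
        return (idx < limit) & (x[idx] != 0)

    def body(idx):
        return idx + 1

    # int32 init keeps the carry, and therefore the result, int32.
    return lax.while_loop(cond, body, jnp.int32(0))

print(while_find_first_zero(jnp.array([3.0, 1.0, 0.0, 2.0]), 10))      # 2
print(jax.jit(while_find_first_zero)(jnp.array([1.0, 2.0, 3.0]), 10))  # 3
```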