hard primitives

Bounded Search via lax.while_loop

Why this matters

Many real algorithms scan an array and stop as soon as a condition is met: finding the first occurrence of a value, detecting the first NaN, or locating the end of a variable-length sequence. Under JIT, these loops cannot be written as a Python for loop with a break, because the stopping condition depends on traced values. They need lax.while_loop, here with a carry that is just the current index.

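This problem adds a real-world wrinkle: the loop must be bounded so that it never reads past the end of the array. An uncapped max_iters that exceeds len(x) would let the index run out of bounds at runtime. JAX does not raise an error for such reads (out-of-bounds indices are typically clamped to the last valid element), so the mistake would be silent rather than loud.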

Worked mini-example

from jax import lax
import jax.numpy as jnp

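def find_negative(x):
    n = x.shape[0]  # static length, known at trace time
    # Continue while idx is in range and the current element is non-negative.
    cond = lambda idx: (idx < n) & (x[idx] >= 0.0)
    body = lambda idx: idx + 1  # advance to the next element
    return lax.while_loop(cond, body, jnp.int32(0))  # carry is the index

find_negative(jnp.array([1.0, 2.0, -3.0, 4.0]))
# → 2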

Key: use & (bitwise AND) or jnp.logical_and, NOT Python and, when combining tracer booleans. Python and calls bool() on its left operand to decide whether to short-circuit, and an abstract tracer has no concrete truth value. Unlike and, & always evaluates both operands.
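The difference between and and & can be seen in isolation. The following is a minimal sketch, not part of the original problem; the function names both_with_and and both_with_amp are made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def both_with_and(a, b):
    # Python `and` calls bool() on the traced value (a > 0), which fails under jit.
    return (a > 0) and (b > 0)

@jax.jit
def both_with_amp(a, b):
    # `&` builds an elementwise AND on the traced booleans instead.
    return (a > 0) & (b > 0)

try:
    both_with_and(jnp.float32(1.0), jnp.float32(2.0))
    and_failed = False
except jax.errors.TracerBoolConversionError:
    and_failed = True

result = both_with_amp(jnp.float32(1.0), jnp.float32(2.0))
print(and_failed, bool(result))
```

The error is raised while the function is being traced, so it appears on the first call rather than at definition time.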

Common pitfalls

  • Python and on tracers. (idx < n) and (x[idx] != 0.0) raises TracerBoolConversionError. Always use &.
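  • Not capping max_iters. If max_iters > len(x), the loop tries to index x[len(x)], which is out-of-bounds. Cap it: min(max_iters, n), or jnp.minimum(max_iters, n) when max_iters is a traced value.
  • Wrong carry dtype. Use jnp.int32(0) as init. A plain Python 0 can produce an int64 carry in some configurations (for example, with 64-bit mode enabled), which then mismatches an int32 body output.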
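The pitfalls above can be combined into one sketch. It extends the worked find_negative example with a step budget; the name find_negative_bounded and the choice of jnp.minimum for the cap are assumptions made for illustration, not the reference solution to the problem.

```python
import jax
import jax.numpy as jnp
from jax import lax

def find_negative_bounded(x, max_iters):
    n = x.shape[0]  # static Python int
    # Cap the step budget at the array length; jnp.minimum also works
    # when max_iters is a traced value under jit.
    bound = jnp.minimum(jnp.asarray(max_iters, dtype=jnp.int32), n)
    # `&` combines the two traced booleans; `idx < bound` masks the
    # read at idx == bound.
    cond = lambda idx: (idx < bound) & (x[idx] >= 0.0)
    body = lambda idx: idx + 1
    # An explicit int32 init keeps the carry dtype fixed.
    return lax.while_loop(cond, body, jnp.int32(0))

x = jnp.array([1.0, 2.0, -3.0, 4.0])
print(find_negative_bounded(x, 100))  # 2: first negative element
print(find_negative_bounded(x, 1))    # 1: budget exhausted after one step
```

Because the cap is computed with jnp.minimum rather than Python min, the same function can be wrapped in jax.jit without marking max_iters as static.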

Problem

Implement while_find_first_zero(x, max_iters) using lax.while_loop. Walk x one element at a time, stopping at the first zero. If no zero is found within min(max_iters, len(x)) steps, return that bound.

  • x: 1-D jax array.
  • max_iters: scalar (cast to int32 internally).
  • Returns: scalar int32, the index of the first zero, or min(len(x), max_iters) if none is found within the limit.
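To make the expected return values concrete, here is a plain-Python reference for the behavior described above. It is only a sketch for checking outputs: it is not JIT-compatible, it does not use lax.while_loop, and the name reference_first_zero is hypothetical.

```python
def reference_first_zero(x, max_iters):
    # Never look past the end of x or past the step budget.
    bound = min(int(max_iters), len(x))
    for i in range(bound):
        if x[i] == 0:
            return i  # index of the first zero within the bound
    return bound      # no zero found within the bound

print(reference_first_zero([3.0, 0.0, 5.0, 0.0], 10))  # 1
print(reference_first_zero([1.0, 2.0, 3.0], 10))       # 3 (len(x) is the bound)
print(reference_first_zero([1.0, 2.0, 3.0, 0.0], 2))   # 2 (max_iters is the bound)
```

A lax.while_loop implementation should agree with this reference on the same inputs.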

Hints

jax while-loop search
