medium
primitives
fori vs scan vs while
Why this matters
JAX provides three structured loop primitives, each with a different API and use-case profile:
- lax.fori_loop: the simplest, for a fixed iteration count with no per-step output. Body is (i, carry) -> carry.
- lax.scan: for a fixed count where you want the full history of intermediate outputs. Body is (carry, x) -> (carry, y); xs is sliced along its leading axis and the ys are stacked automatically.
- lax.while_loop: for data-dependent termination (iteration count unknown at trace time). Body is state -> state.
Knowing which primitive fits which problem, and what each body signature looks like, is essential for writing correct JAX control flow.
Worked mini-example
from jax import lax
import jax.numpy as jnp
# Sum 0 + 1 + ... + (n - 1) three ways
n = 4
# fori: no per-step output, just the final acc
total = lax.fori_loop(0, n, lambda i, acc: acc + i, 0)
# scan: accumulate AND collect the carry after every step (a running sum)
carry, ys = lax.scan(lambda acc, x: (acc + x, acc + x), 0, jnp.arange(n))
# while: state is (counter, acc); stop when the condition on the state fails
_, total_w = lax.while_loop(
    lambda s: s[0] < n,
    lambda s: (s[0] + 1, s[1] + s[0]),
    (0, 0)
)
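The while_loop above still runs a fixed number of steps, so it does not show the data-dependent case by itself. The following is a minimal sketch of a loop whose trip count depends on the input value. The function name halvings_until_below and its parameters are hypothetical, chosen only for illustration, and the sketch assumes threshold is positive (otherwise the loop would not terminate).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def halvings_until_below(x, threshold):
    """Count how many times x is halved before it drops below threshold.

    Hypothetical helper; assumes threshold > 0.
    """

    def cond(state):
        value, _ = state
        # Keep looping while the value has not dropped below the threshold.
        return value >= threshold

    def body(state):
        value, steps = state
        # The carry keeps the same structure and dtypes on every iteration.
        return value / 2.0, steps + 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.asarray(0, dtype=jnp.int32))
    _, steps = lax.while_loop(cond, body, init)
    return steps


steps = halvings_until_below(100.0, 1.0)
```

Because the stopping point depends on the runtime value of x, neither fori_loop nor scan could express this loop without fixing an upper bound on the iteration count in advance.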
Common pitfalls
- Mixing up body signatures. fori is (i, carry), scan is (carry, x), and while takes a single state argument. Confusing them silently swaps the roles of the arguments.
- scan needs a static xs shape. lax.scan(f, init, xs) iterates over the leading axis of xs. If n is a Python-level int you can write jnp.ones(n) directly; if n were a JAX tracer, the shape would be unknown at trace time and building xs would fail before scan could run.
- while carry dtype. The counter i should be an int (e.g. 0, not 0.0) so the comparison i < n stays integer. The carry must also keep the same shape and dtype on every iteration.
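The last two pitfalls can be sidestepped as in the sketch below. It assumes n is a static Python int and uses scan's length argument with xs=None, so no dummy xs array is needed. The carries are given explicit dtypes so they stay fixed across iterations. The name step is illustrative only.

```python
import jax.numpy as jnp
from jax import lax

n = 5  # static Python int, known at trace time


def step(carry, _):
    # With xs=None the second argument is None on every step.
    new = carry * 2
    return new, new


# scan without materializing xs: pass xs=None and a static length.
final, doubled = lax.scan(step, jnp.int32(1), xs=None, length=n)

# while_loop with an explicit int32 counter and a float32 accumulator.
init = (jnp.int32(0), jnp.float32(0.0))
_, acc = lax.while_loop(
    lambda s: s[0] < n,                # integer comparison on the counter
    lambda s: (s[0] + 1, s[1] + 0.5),  # dtypes of both carry parts are preserved
    init,
)
```

Here doubled has shape (n,) and holds the carry after each step, while acc is the float32 accumulator after n iterations.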
Problem
Implement power_three_ways(x, n) that computes x^n using each of the
three primitives and returns all three results as a 1-D array.
- x: scalar.
- n: scalar (treated as a static Python int internally; cast with int(n)).
- Returns: 1-D array of shape (3,), laid out as [fori_result, scan_result, while_result]. All three values should equal x^n.
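One possible approach is sketched below. It is not the site's reference solution. It assumes n is a non-negative integer and casts x to float32, which is a choice made here and not something the problem statement requires.

```python
import jax.numpy as jnp
from jax import lax


def power_three_ways(x, n):
    n = int(n)  # static Python int, as the problem statement suggests
    x = jnp.asarray(x, dtype=jnp.float32)  # assumption: work in float32
    one = jnp.ones_like(x)

    # fori_loop: body is (i, carry) -> carry; the index i is unused.
    p_fori = lax.fori_loop(0, n, lambda i, acc: acc * x, one)

    # scan: body is (carry, x_t) -> (carry, y); no per-step output is kept.
    p_scan, _ = lax.scan(lambda acc, _: (acc * x, None), one, xs=None, length=n)

    # while_loop: state is (counter, accumulator).
    _, p_while = lax.while_loop(
        lambda s: s[0] < n,
        lambda s: (s[0] + 1, s[1] * x),
        (jnp.int32(0), one),
    )

    return jnp.stack([p_fori, p_scan, p_while])


result = power_three_ways(2.0, 5)
```

Each loop multiplies the accumulator by x exactly n times, so the three entries of result should agree up to floating-point rounding.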
Hints
jax
fori-loop
scan
while-loop