medium primitives

fori vs scan vs while

Why this matters

JAX provides three structured loop primitives, each with a different API and use-case profile:

  • lax.fori_loop: the simplest, for a fixed iteration count with no per-step output. Body is (i, carry) → carry.
  • lax.scan: for a fixed count where you also want the full history of intermediate outputs. Body is (carry, x) → (carry, y); xs is sliced along its leading axis, and the per-step ys are stacked along a new leading axis automatically.
  • lax.while_loop: for data-dependent termination (iteration count unknown at trace time). Body is state → state, paired with a condition function state → bool.
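A minimal sketch of the two features the bullets single out: scan slicing a pytree of xs and stacking a pytree of ys, and while_loop stopping on a data-dependent condition. The helper names (running_stats, count_halvings), the dict keys, and the tolerance are hypothetical choices made only for this illustration.

```python
import jax.numpy as jnp
from jax import lax


def running_stats(xs, ws):
    """scan over two arrays at once; both are sliced along axis 0."""
    def body(total, inputs):
        x, w = inputs                      # one slice of each array per step
        total = total + x * w
        # ys may be any pytree; scan stacks each leaf along a new axis 0
        return total, {"partial": total, "term": x * w}
    return lax.scan(body, jnp.float32(0.0), (xs, ws))


def count_halvings(x, tol=1.0):
    """while_loop whose trip count depends on the value of x, not on a fixed n."""
    def cond(state):
        _, v = state
        return v >= tol                    # state -> bool
    def body(state):
        i, v = state
        return i + 1, v / 2.0              # integer counter, float value
    return lax.while_loop(cond, body, (0, jnp.float32(x)))


total, ys = running_stats(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 0.5, 2.0]))
# total = 1*1 + 2*0.5 + 3*2 = 8.0
# ys["partial"] = [1.0, 2.0, 8.0], ys["term"] = [1.0, 1.0, 6.0]

steps, final = count_halvings(10.0)
# 10 -> 5 -> 2.5 -> 1.25 -> 0.625, so steps = 4 and final = 0.625
```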

Knowing which primitive fits which problem, and what their body signatures are, is essential for writing correct JAX control flow.

Worked mini-example

from jax import lax
import jax.numpy as jnp

# Cumulative sum of 0..n-1, three ways
n = 4

# fori: no per-step output, just the final accumulator
total = lax.fori_loop(0, n, lambda i, acc: acc + i, 0)  # 6

# scan: accumulate AND collect the running total at every step
carry, ys = lax.scan(lambda acc, x: (acc + x, acc + x), 0, jnp.arange(n))
# carry = 6, ys = [0, 1, 3, 6]

# while: keep looping as long as the condition on the carried state holds
_, total_w = lax.while_loop(
    lambda s: s[0] < n,                  # cond: state -> bool
    lambda s: (s[0] + 1, s[1] + s[0]),   # body: (i, acc) -> (i + 1, acc + i)
    (0, 0)
)  # total_w = 6
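One way to confirm that the three loops above agree is to wrap them in a single function and jit it. This sketch assumes n is marked static via static_argnums, so it stays a Python int during tracing and both the loop bounds and jnp.arange(n) remain concrete; the name cumsum_three_ways is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnums=0)
def cumsum_three_ways(n):
    # n is static, so it is an ordinary Python int inside this function
    total = lax.fori_loop(0, n, lambda i, acc: acc + i, 0)
    carry, ys = lax.scan(lambda acc, x: (acc + x, acc + x), 0, jnp.arange(n))
    _, total_w = lax.while_loop(
        lambda s: s[0] < n,
        lambda s: (s[0] + 1, s[1] + s[0]),
        (0, 0),
    )
    return jnp.stack([total, carry, total_w]), ys


results, history = cumsum_three_ways(4)
# results -> [6, 6, 6]; history -> [0, 1, 3, 6]
```

Each distinct value of n triggers a fresh trace, which is the usual trade-off of static arguments.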

Common pitfalls

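  • Mixing up body signatures. fori's body is (i, carry), scan's is (carry, x), and while's body takes a single state (as does its cond function). Swapping the argument order can raise a shape or type error or, when the types happen to line up, silently compute the wrong result.
  • scan needs a static length. lax.scan(f, init, xs) iterates over the leading axis of xs, so that length must be known at trace time. If n is a Python int you can build xs with jnp.ones(n) directly, or skip xs and pass length=n; if n were a JAX tracer, jnp.ones(n) would already fail because array shapes must be concrete.
  • while carry dtype. Start the counter i as an int (0, not 0.0) so the comparison i < n stays integer; more generally, the body must return a state with the same structure, shapes, and dtypes as the initial state.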
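The first two pitfalls can be reproduced in a few lines. The update acc * 2 + i is an arbitrary, deliberately asymmetric body chosen so that a swapped signature changes the result instead of raising; count_up is a hypothetical helper showing scan driven only by a static length, with no xs at all.

```python
from functools import partial

import jax
from jax import lax


# Intended fori body: (i, acc) -> acc * 2 + i
def step(i, acc):
    return acc * 2 + i


# Same expression, but parameters declared in scan-style order (acc, i).
# fori_loop still passes (i, acc), so the roles are silently swapped.
def step_swapped(acc, i):
    return acc * 2 + i


good = lax.fori_loop(0, 4, step, 0)          # 0 -> 0 -> 1 -> 4 -> 11
bad = lax.fori_loop(0, 4, step_swapped, 0)   # 0 -> 0 -> 2 -> 6 -> 12, no error


# scan with xs=None: the trip count comes from `length`, which must be a
# concrete Python int. Marking n static keeps it concrete under jit.
@partial(jax.jit, static_argnums=1)
def count_up(start, n):
    def body(c, _):          # the x argument is None on every step
        return c + 1, c      # carry the next value, emit the current one
    return lax.scan(body, start, None, length=n)


last, seen = count_up(0, 4)  # last = 4, seen = [0, 1, 2, 3]
```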

Problem

Implement power_three_ways(x, n) that computes x^n using each of the three primitives and returns all three results as a 1-D array.

  • x: scalar.
  • n: scalar (treated internally as a static Python int; cast with int(n)).
  • Returns: 1-D array of shape (3,), [fori_result, scan_result, while_result]. All three values should equal x^n.
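A minimal sketch of one possible approach, not the site's reference solution. It assumes x is a numeric scalar and n is a non-negative integer-valued scalar that int() can convert, so it will not work if n is a traced value under jit. Every loop starts from jnp.ones_like(x) and multiplies by x exactly n times.

```python
import jax.numpy as jnp
from jax import lax


def power_three_ways(x, n):
    x = jnp.asarray(x)
    n = int(n)                       # static trip count, as the spec requires
    one = jnp.ones_like(x)

    # fori_loop: body is (i, carry) -> carry; i is unused here
    fori_result = lax.fori_loop(0, n, lambda i, acc: acc * x, one)

    # scan: body is (carry, x_t) -> (carry, y); no xs, so pass length=n
    scan_result, _ = lax.scan(lambda acc, _: (acc * x, None), one, None, length=n)

    # while_loop: state is (counter, acc); stop once the counter reaches n
    _, while_result = lax.while_loop(
        lambda s: s[0] < n,
        lambda s: (s[0] + 1, s[1] * x),
        (0, one),
    )
    return jnp.stack([fori_result, scan_result, while_result])
```

For example, power_three_ways(2.0, 10) gives [1024., 1024., 1024.], and n = 0 returns [1., 1., 1.] because all three loops then return their initial carry unchanged.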

Hints

jax fori-loop scan while-loop
