x^n via lax.fori_loop
Why this matters
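lax.fori_loop(lower, upper, body_fun, init_val) is JAX's construct for
bounded-counter loops: loops that step an integer index from lower up to, but
not including, upper. Unlike lax.while_loop, you do not write a termination
predicate; you simply supply a lower and upper bound and the loop runs for
exactly upper - lower iterations. The bounds may be concrete Python ints or
traced JAX integers. When both are concrete, the trip count is fixed at trace
time and JAX implements the loop with lax.scan, which supports reverse-mode
differentiation; otherwise it falls back to lax.while_loop, which supports
forward-mode differentiation only.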
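A small sketch of the concrete-bounds case, assuming a recent JAX release: with Python int bounds the trip count is fixed at trace time, so jax.grad can differentiate through the loop in reverse mode. The function name cube is only illustrative.

```python
import jax
from jax import lax


def cube(x):
    # Static Python-int bounds (0, 3): the trip count is known at trace time,
    # so fori_loop can be implemented via scan and reverse-mode AD works.
    # The loop index i is unused; the carry starts at 1.0.
    return lax.fori_loop(0, 3, lambda i, acc: acc * x, 1.0)


print(cube(2.0))            # 8.0
print(jax.grad(cube)(2.0))  # 12.0, i.e. 3 * x**2 at x = 2
```

If the bounds were traced values instead, the loop would lower to lax.while_loop and jax.grad would raise an error, because while_loop supports only forward-mode differentiation.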
This makes it the natural replacement for for i in range(n): ... under jit.
A Python for loop still works under jit when n is a static Python int, but it
is unrolled at trace time, so the compiled program grows with n. lax.fori_loop
is the explicit, XLA-idiomatic way to express the same intent: it compiles to a
single loop and also works when n is a traced JAX integer.
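To make the contrast concrete, here is a sketch comparing the two styles under jit (the function names are hypothetical). The Python loop needs n marked static via static_argnums, so each distinct n triggers a fresh trace and compile; the fori_loop version accepts n as an ordinary traced argument.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sum_python_loop(n):
    # n must be a static Python int: the loop is unrolled at trace time,
    # so the traced program contains n separate additions.
    acc = jnp.int32(0)
    for i in range(n):
        acc = acc + i
    return acc


# static_argnums=0 makes n a compile-time constant so range(n) is legal.
sum_python_jit = jax.jit(sum_python_loop, static_argnums=0)


@jax.jit
def sum_fori(n):
    # n is traced here; fori_loop emits one loop whose trip count
    # is read at run time.
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)


print(sum_python_jit(5))  # 10
print(sum_fori(5))        # 10
print(sum_fori(100))      # 4950, reusing the same compiled function
```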
Worked mini-example
from jax import lax
import jax.numpy as jnp

def sum_n(n):
    # Sum 0..n-1 using fori_loop
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

sum_n(5)  # → 10
body_fun takes (i, state), the loop index first and the carry second, and
returns the new carry. The index i is available if you need it, but you are
free to ignore it.
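The carry does not have to be a single number. A sketch, assuming a hypothetical helper sum_of_squares_and_evens, where the carry is a tuple and the body actually uses the index i:

```python
import jax.numpy as jnp
from jax import lax


def sum_of_squares_and_evens(n):
    # Hypothetical helper: the carry is a tuple (total, evens).
    def body(i, carry):
        total, evens = carry                 # the carry comes second
        total = total + i * i                # uses the loop index i
        evens = evens + (i % 2 == 0).astype(jnp.int32)
        return (total, evens)                # same structure and dtypes as init

    init = (jnp.int32(0), jnp.int32(0))
    return lax.fori_loop(0, n, body, init)


total, evens = sum_of_squares_and_evens(5)
print(total, evens)  # 30 3
```

The carry can be any pytree, but its structure, shapes, and dtypes must stay the same on every iteration.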
Common pitfalls
- Swapped signature. The body is body(i, state), NOT body(state, i).
Accidentally writing lambda acc, i: acc * x passes the carry as i and the
index as acc, causing silent wrong answers or shape errors.
- Not casting n. Plain Python ints work for the bounds, but if n arrives as a
JAX float scalar (e.g. from a traced argument), wrap it with jnp.int32(n)
before passing it as the upper bound.
- Forgetting init. The accumulator must be initialised. For x^n the
multiplicative identity is 1.0; starting at 0.0 gives zero for all inputs.
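The swapped-signature and init pitfalls are easy to reproduce. A minimal sketch with concrete values x = 3.0 and n = 4, chosen only for illustration:

```python
from jax import lax

x, n = 3.0, 4

# Correct: body(i, acc), carry initialised to the multiplicative identity 1.0.
good = lax.fori_loop(0, n, lambda i, acc: acc * x, 1.0)       # 81.0

# Swapped-signature pitfall: fori_loop still calls body(index, carry), so
# "acc" here is really the loop index and "i" is the carry. Each step
# overwrites the carry with index * x.
swapped = lax.fori_loop(0, n, lambda acc, i: acc * x, 1.0)    # 9.0

# Init pitfall: starting the product at 0.0 makes every result 0.0.
zero_init = lax.fori_loop(0, n, lambda i, acc: acc * x, 0.0)  # 0.0

print(good, swapped, zero_init)
```

The swapped version runs without any error, which is what makes it dangerous: its final carry is just (n - 1) * x.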
Problem
Implement fori_power(x, n) that computes x^n using exactly n repeated
multiplications via lax.fori_loop.
- x: scalar.
- n: scalar (cast to jnp.int32 internally).
- Returns: scalar x^n.
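One possible sketch of fori_power, not the site's reference solution: it casts n with jnp.int32 as the spec says, starts the carry at the multiplicative identity in x's dtype via jnp.ones_like, and keeps the (i, acc) argument order.

```python
import jax
import jax.numpy as jnp
from jax import lax


def fori_power(x, n):
    x = jnp.asarray(x)
    n = jnp.int32(n)          # integer upper bound, even if n arrives as a float
    init = jnp.ones_like(x)   # multiplicative identity with x's dtype
    # body(i, acc): index first, carry second; i itself is unused.
    return lax.fori_loop(0, n, lambda i, acc: acc * x, init)


print(fori_power(2.0, 10))             # 1024.0
print(fori_power(3.0, 0))              # 1.0 (zero iterations return init)
print(jax.jit(fori_power)(1.5, 2.0))   # 2.25
```

The jitted call passes n as a float to show that the cast lets a traced float bound through; with a traced bound the loop lowers to a while_loop.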