easy primitives

x^n via lax.fori_loop

Why this matters

lax.fori_loop(lower, upper, body_fun, init_val) is JAX's primitive for bounded-counter loops: loops whose iteration count is fixed before the loop starts. The bounds may be Python ints or JAX integer scalars. Unlike lax.while_loop, the termination condition does not depend on values computed inside the loop; you supply a lower and an upper bound, and the loop runs for exactly upper - lower iterations (zero if upper <= lower).
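As a rough mental model, the loop behaves like the pure-Python function below. The name fori_loop_reference is a hypothetical helper introduced here for illustration; it is not part of JAX, and it only mirrors the semantics for concrete integer bounds.

```python
from jax import lax


def fori_loop_reference(lower, upper, body_fun, init_val):
    # Pure-Python model: thread the carry through every index in [lower, upper).
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


# The same body works for both versions: index first, carry second.
body = lambda i, acc: acc + 2 * i

py_result = fori_loop_reference(2, 6, body, 0)  # 2 * (2 + 3 + 4 + 5) = 28
jax_result = lax.fori_loop(2, 6, body, 0)       # same value, as a JAX scalar
```

Both calls visit the indices 2, 3, 4, 5, so upper is exclusive, just as in range.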

This makes it the natural replacement for for i in range(n): … under jit. A Python for loop works under jit only when n is a static Python int, in which case the loop is unrolled at trace time. lax.fori_loop is the explicit, XLA-idiomatic way to express the same intent, and it also works when n is a traced JAX integer.
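The contrast can be sketched as follows. The function names sum_below and sum_below_unrolled are illustrative only. The first version receives n as a traced value and still works; the second relies on marking n as static so that range(n) sees a plain Python int.

```python
from functools import partial

import jax
from jax import lax


@jax.jit
def sum_below(n):
    # n is traced here, so range(n) would fail at trace time;
    # fori_loop accepts a traced integer as its upper bound.
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)


@partial(jax.jit, static_argnums=0)
def sum_below_unrolled(n):
    # n is a static Python int, so the Python loop is unrolled while tracing.
    acc = 0
    for i in range(n):
        acc = acc + i
    return acc


total = sum_below(5)                # 0 + 1 + 2 + 3 + 4
unrolled = sum_below_unrolled(5)    # same value via the unrolled loop
```

One trade-off to keep in mind: because n is static in the second version, each new value of n triggers a fresh trace and compilation, whereas the fori_loop version is compiled once for any integer n.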

Worked mini-example

from jax import lax
import jax.numpy as jnp

def sum_n(n):
    # Sum 0..n-1 using fori_loop
    return lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

sum_n(5)   # → 10

body_fun takes (i, state), the loop index first and the carry second, and returns the new carry. The index i is available if you need it, but you are free to ignore it.

Common pitfalls

  • Swapped signature. The body is body(i, state), NOT body(state, i). Accidentally writing lambda acc, i: acc * x binds the index to acc and the carry to i, causing silently wrong answers or a carry type-mismatch error.
  • Not casting n. Plain Python ints work for the bounds, but if n arrives as a JAX float scalar (e.g. from a traced argument), wrap it with jnp.int32(n) before passing it as the upper bound.
  • Forgetting init. The accumulator must be initialised. For x^n, use the multiplicative identity 1.0; starting at 0.0 gives zero for all inputs.
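The three pitfalls above can be reproduced in a few lines. This is a minimal sketch; count_steps and the variable names are hypothetical and chosen only for the demonstration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = 3.0

# Swapped signature: the index lands in `acc`, so each step returns index * x.
swapped = lax.fori_loop(0, 3, lambda acc, i: acc * x, 1.0)  # 6.0, not 27.0
correct = lax.fori_loop(0, 3, lambda i, acc: acc * x, 1.0)  # 27.0


@jax.jit
def count_steps(n):
    # n may arrive as a float scalar; cast it before using it as a bound.
    n_int = jnp.int32(n)
    return lax.fori_loop(0, n_int, lambda i, acc: acc + 1, 0)


steps = count_steps(jnp.float32(4.0))  # the loop body runs 4 times

# Wrong init: repeated multiplication starting from 0.0 never leaves 0.0.
zero_init = lax.fori_loop(0, 2, lambda i, acc: acc * x, 0.0)
one_init = lax.fori_loop(0, 2, lambda i, acc: acc * x, 1.0)
```

Note that the swapped version raises no error in this case, which is what makes the mistake easy to miss.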

Problem

Implement fori_power(x, n) that computes x^n using exactly n repeated multiplications via lax.fori_loop.

  • x: scalar.
  • n: scalar (cast to jnp.int32 internally).
  • Returns: scalar x^n.
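Before submitting, it can help to check an attempt against Python's built-in ** operator. The checker below is a hypothetical helper, not part of the problem or of JAX; the test cases and the tolerance are assumptions chosen for illustration.

```python
def check_fori_power(candidate, rtol=1e-5):
    # Compare candidate(x, n) with Python's x ** n on a few small cases,
    # including n = 0 (which should return 1) and a negative base.
    cases = [(2.0, 0), (2.0, 10), (-3.0, 3), (0.5, 4), (1.5, 1)]
    for x, n in cases:
        expected = x ** n
        got = float(candidate(x, n))
        tolerance = rtol * max(1.0, abs(expected))
        assert abs(got - expected) <= tolerance, (x, n, got, expected)
    return True
```

Calling check_fori_power(fori_power) returns True when every case matches and raises an AssertionError naming the failing (x, n) pair otherwise.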

Hints

jax fori-loop iteration
