medium
primitives
Fix: Python for → vmap
Why this matters
Python for loops over JAX arrays unroll at trace time. Each loop
iteration adds a copy of the loop body to the jaxpr, so a length-N array
produces N copies of the body in the compiled program. This causes slow
compilation, large memory usage at compile time, and bloated XLA graphs,
all without any runtime speedup.
The BROKEN pattern:
import jax

@jax.jit
def broken_sum_squares(x):
    total = 0
    for xi in x:  # unrolls N times in the jaxpr!
        total += xi ** 2
    return total
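One way to observe the unrolling is to count the equations in the traced jaxpr for a few array lengths. The following is a minimal sketch, not part of the original problem; the names `loop_sum_squares`, `vectorized_sum_squares`, and `count_eqns` are hypothetical helpers introduced for illustration, and the exact equation counts may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def loop_sum_squares(x):
    # Hypothetical helper: the looped pattern, left un-jitted so it can be traced below
    total = 0.0
    for xi in x:  # Python loop: the body is traced once per element
        total += xi ** 2
    return total

def vectorized_sum_squares(x):
    # Hypothetical helper: the vectorized equivalent
    return jnp.sum(x ** 2)

def count_eqns(fn, n):
    """Number of equations in the jaxpr of fn traced on a length-n array."""
    closed = jax.make_jaxpr(fn)(jnp.ones(n))
    return len(closed.jaxpr.eqns)

for n in (4, 16, 64):
    # The looped count grows with n; the vectorized count stays the same
    print(n, count_eqns(loop_sum_squares, n), count_eqns(vectorized_sum_squares, n))
```

The looped version should report a count that grows with the array length, while the vectorized version should report the same small count for every length.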
Worked mini-example
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# BROKEN: Python loop unrolls at trace time, giving 3 copies of the body
# total = 0
# for xi in x:
#     total += xi ** 2

# FIXED: vectorized ops; the jaxpr size does not depend on len(x)
out = jnp.sum(x ** 2)
# → 14.0
Common pitfalls
- Python loops over JAX arrays sometimes silently work: if the array length is statically known and small, JAX traces through them without error; the bug is the compilation cost, not a runtime crash.
- For traced loop counts use `lax.fori_loop`: when the number of iterations is itself a traced value, a Python `for` cannot be used at all; use `jax.lax.fori_loop` (for counted loops, whose bounds may be traced). For loops over a sequence of static length, `jax.lax.scan` runs the body without unrolling it.
- `jax.vmap` for per-element operations: if you need a custom per-element function `f`, use `jax.vmap(f)(x)` instead of a loop.
- `jnp.sum(x ** 2)` typically compiles to a single fused kernel, which is far more efficient than accumulating a Python scalar in a loop.
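The alternatives named in the list above can be sketched as follows. This is an illustrative sketch under stated assumptions rather than a reference implementation: `square`, `sum_squares_vmap`, `sum_squares_scan`, and `sum_first_n_squares` are hypothetical names, and `square` stands in for any custom per-element function.

```python
import jax
import jax.numpy as jnp
from jax import lax

def square(xi):
    # Hypothetical per-element function; stands in for any custom f
    return xi ** 2

@jax.jit
def sum_squares_vmap(x):
    # jax.vmap maps `square` over the leading axis without a Python loop
    return jnp.sum(jax.vmap(square)(x))

@jax.jit
def sum_squares_scan(x):
    # lax.scan carries a running total over the sequence; the body is traced once
    def step(total, xi):
        return total + xi ** 2, None  # (new carry, no per-step output)
    total, _ = lax.scan(step, jnp.zeros((), x.dtype), x)
    return total

@jax.jit
def sum_first_n_squares(x, n):
    # n is a traced value here, so a Python `for` over range(n) would fail
    def body(i, total):
        return total + x[i] ** 2
    return lax.fori_loop(0, n, body, jnp.zeros((), x.dtype))

x = jnp.array([1.0, 2.0, 3.0])
print(sum_squares_vmap(x), sum_squares_scan(x), sum_first_n_squares(x, 2))
```

For a plain sum of squares, `jnp.sum(x ** 2)` remains the simplest choice; the loop primitives are mainly useful when each step depends on the result of the previous one.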
Problem
Implement safe_loop_sum_squares(x) that returns the sum of squares of
all elements of x using a vectorized JAX operation, with no Python loop.
- x: 1-D JAX array.
Returns: scalar, sum(x^2).
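One possible implementation matching this specification is sketched below. It is not the site's reference solution, which is not shown in the source; it simply applies the vectorized pattern from the worked mini-example, and wrapping it in `jax.jit` is an assumption.

```python
import jax
import jax.numpy as jnp

@jax.jit  # assumption: the problem does not say whether jit is required
def safe_loop_sum_squares(x):
    # Square every element, then reduce to a scalar; no Python loop involved
    return jnp.sum(x ** 2)

result = safe_loop_sum_squares(jnp.array([1.0, 2.0, 3.0]))
print(result)
```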
Hints
jax
vmap
fix-broken