medium primitives

Fix: Python for → vmap

Why this matters

Python for loops over JAX arrays unroll at trace time. Each loop iteration adds a copy of the loop body to the jaxpr, so a length-N array produces N copies of the body in the compiled program. This causes slow compilation, high memory usage at compile time, and bloated XLA graphs, all without any runtime speedup.

The BROKEN pattern:

import jax

@jax.jit
def broken_sum_squares(x):
    total = 0
    for xi in x:          # unrolls N times in the jaxpr!
        total += xi ** 2
    return total
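
One way to observe the unrolling described above is to count the equations in the traced jaxpr for two input lengths. This is a minimal sketch: the helper count_eqns and both function names are illustrative, not part of the exercise, and the exact counts depend on the JAX version, so only the trend is meaningful.

```python
import jax
import jax.numpy as jnp


def loop_sum_squares(x):
    # Python loop: the body is traced once per element of x.
    total = 0.0
    for xi in x:
        total += xi ** 2
    return total


def vectorized_sum_squares(x):
    # One elementwise power plus one reduction, whatever the length of x.
    return jnp.sum(x ** 2)


def count_eqns(fn, n):
    # Hypothetical helper: number of top-level equations in the jaxpr
    # traced for a length-n input.
    return len(jax.make_jaxpr(fn)(jnp.ones(n)).jaxpr.eqns)


loop_small = count_eqns(loop_sum_squares, 4)
loop_large = count_eqns(loop_sum_squares, 32)
vec_small = count_eqns(vectorized_sum_squares, 4)
vec_large = count_eqns(vectorized_sum_squares, 32)

# The loop version grows with the input length; the vectorized one does not.
print(loop_small, loop_large, vec_small, vec_large)
```

The loop version's equation count should scale with the array length, while the vectorized version's count stays the same for both lengths.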

Worked mini-example

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# BROKEN: Python loop unrolls at trace time, giving 3 copies of the body
# total = 0
# for xi in x:
#     total += xi ** 2

# FIXED: single vectorized expression; the jaxpr size no longer depends on len(x)
out = jnp.sum(x ** 2)
# → 14.0

Common pitfalls

  • Python loops over JAX arrays silently work: under jit the array length is static, so JAX traces through the loop without error, and for small arrays the cost goes unnoticed. The bug is the compilation cost, not a runtime crash.
  • Use lax.scan or lax.fori_loop for loops that must stay loops: jax.lax.scan iterates over a sequence without unrolling, and jax.lax.fori_loop handles counted loops, including ones whose iteration count is itself a traced value, where a Python for over range(n) cannot be used at all.
  • jax.vmap for per-element operations: if you need a custom per-element function f, use jax.vmap(f)(x) instead of a loop.
  • jnp.sum(x ** 2) typically compiles to a single fused kernel, which is far more efficient than accumulating a scalar in a Python loop.
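
The alternatives named in the list above can be sketched side by side. This is an illustrative sketch, not the exercise's reference solution: square, vmap_sum_squares, scan_sum_squares, and partial_sum_squares are hypothetical names, and the last one sums only the first n elements to show a traced loop bound.

```python
import jax
import jax.numpy as jnp
from jax import lax


def square(xi):
    # Hypothetical per-element function; stands in for any custom scalar f.
    return xi ** 2


@jax.jit
def vmap_sum_squares(x):
    # vmap maps `square` over the leading axis without a Python loop.
    return jnp.sum(jax.vmap(square)(x))


@jax.jit
def scan_sum_squares(x):
    # scan carries a running total across the elements of x.
    def step(total, xi):
        return total + xi ** 2, None  # (new carry, no per-step output)

    total, _ = lax.scan(step, jnp.zeros((), x.dtype), x)
    return total


@jax.jit
def partial_sum_squares(x, n):
    # n is traced under jit, so range(n) would fail; fori_loop accepts it.
    def body(i, total):
        return total + x[i] ** 2

    return lax.fori_loop(0, n, body, jnp.zeros((), x.dtype))


x = jnp.array([1.0, 2.0, 3.0])
a = vmap_sum_squares(x)        # 1 + 4 + 9
b = scan_sum_squares(x)        # same value, computed sequentially
c = partial_sum_squares(x, 2)  # 1 + 4, first two elements only
```

For a plain sum of squares, jnp.sum(x ** 2) remains the simplest choice; vmap, scan, and fori_loop earn their place when the per-element function or the loop-carried state is more involved.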

Problem

Implement safe_loop_sum_squares(x) that returns the sum of squares of all elements of x using a vectorized JAX operation, with no Python loop.

  • x: 1-D JAX array.

Returns: scalar, sum(x^2).

Hints

jax vmap fix-broken
