jit of scan
Why this matters
In production JAX code, performance-critical sequential loops are usually not written as Python for loops. Instead, the loop body is expressed with lax.scan, and the enclosing function is compiled once with jit. The combination jit(scan_fn) traces the entire sequential loop into a single XLA program, eliminating Python dispatch overhead on every step.
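One way to see what "a single XLA program" means is to inspect the traced program with jax.make_jaxpr, which shows what jit would hand to XLA. The sketch below uses hypothetical function names: the scan version traces to one scan primitive whose body contains a single multiply, while an equivalent Python for loop is unrolled at trace time into one multiply per element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def running_product_scan(x):
    # lax.scan traces the body once and repeats it inside the compiled program.
    def step(carry, xi):
        new = carry * xi
        return new, new

    _, ys = lax.scan(step, jnp.ones((), dtype=x.dtype), x)
    return ys


def running_product_loop(x):
    # A Python for loop runs during tracing, so every iteration is
    # recorded as separate operations (the loop is unrolled).
    carry = jnp.ones((), dtype=x.dtype)
    outs = []
    for i in range(x.shape[0]):
        carry = carry * x[i]
        outs.append(carry)
    return jnp.stack(outs)


x = jnp.arange(1.0, 9.0)  # 8 elements: 1.0 ... 8.0

# make_jaxpr prints the traced program that jit would lower to XLA.
scan_jaxpr = str(jax.make_jaxpr(running_product_scan)(x))
loop_jaxpr = str(jax.make_jaxpr(running_product_loop)(x))

print("scan" in scan_jaxpr)  # True: the whole loop is one scan primitive
print(scan_jaxpr.count("mul"), loop_jaxpr.count("mul"))
```

Both functions compute the same values, but the size of the unrolled program grows with the length of x, while the scan program stays the same size.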
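This pattern appears in multi-step optimizer loops (for example, scanning over a sequence of optax updates), recurrent model training, and any scenario where you need both compilation and sequential state. The first call traces and compiles the function; subsequent calls with the same input shapes and dtypes reuse the cached executable, so they skip tracing and per-step Python dispatch entirely. A call with a new shape or dtype triggers a fresh trace and compile.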
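A minimal sketch of the compile-once behavior, using a hypothetical module-level counter trace_count: Python code inside a jitted function runs only while JAX traces it, so incrementing a counter there reveals when a new trace (and compilation) happens. The cache is keyed on input shapes and dtypes, so a new input length forces a retrace.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0  # hypothetical counter; incremented only during tracing


@jax.jit
def cumulative_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only

    def step(carry, xi):
        new = carry + xi
        return new, new

    _, ys = lax.scan(step, jnp.zeros((), dtype=x.dtype), x)
    return ys


cumulative_sum(jnp.array([1.0, 2.0, 3.0]))  # first call: trace + compile
cumulative_sum(jnp.array([4.0, 5.0, 6.0]))  # same shape/dtype: cache hit
print(trace_count)  # 1

cumulative_sum(jnp.array([1.0, 2.0, 3.0, 4.0]))  # new length: retrace
print(trace_count)  # 2
```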
Worked mini-example
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def scan_product(x):
    """Running product via scan, compiled with jit."""
    def step(carry, xi):
        return carry * xi, carry * xi
    _, ys = lax.scan(step, 1.0, x)
    return ys

scan_product(jnp.array([1.0, 2.0, 3.0, 4.0]))
# → [1., 2., 6., 24.] (compiled on first call)
@jax.jit decorates the function definition; the function is compiled
lazily on its first call. You can also apply jit inline:
f = jax.jit(my_fn).
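The inline form can be sketched as follows; running_product and fast_running_product are hypothetical names. jax.jit returns a new compiled callable and leaves the original Python function as it was, so you can keep an unjitted version around for debugging and call the compiled one where speed matters.

```python
import jax
import jax.numpy as jnp
from jax import lax


def running_product(x):
    # Plain Python function; not wrapped in jax.jit.
    def step(carry, xi):
        new = carry * xi
        return new, new

    _, ys = lax.scan(step, jnp.ones((), dtype=x.dtype), x)
    return ys


# Inline jit: pass the function itself, not the result of calling it.
fast_running_product = jax.jit(running_product)

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(fast_running_product(x))  # running product: 1, 2, 6, 24
print(running_product(x))       # same values, without the jit wrapper
```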
Common pitfalls
- jit-ing the call, not the function. jax.jit(my_fn(x)) evaluates my_fn(x) eagerly and then passes the result (an array, not a function) to jax.jit, which fails because jit expects a callable. Always write jax.jit(my_fn)(x) or decorate my_fn with @jax.jit.
- scan body not pure. Any Python side effect inside step (print, list append) runs only at trace time. It will not execute on subsequent jit'd calls. Keep scan bodies pure; if you need output on every call, use jax.debug.print instead of print.
- Confusing carry and outputs. lax.scan returns (final_carry, stacked_ys). If you only need the final accumulated value, take final_carry (index 0) and emit None as the per-step output.
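The carry-only pattern from the last pitfall can be sketched like this, with final_product as a hypothetical name: the step function returns None as its per-step output, so lax.scan stacks nothing, and the useful result is the final carry at index 0 of the returned tuple.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def final_product(x):
    def step(carry, xi):
        # Emit None as the per-step output: nothing is stacked.
        return carry * xi, None

    final_carry, ys = lax.scan(step, jnp.ones((), dtype=x.dtype), x)
    # ys is None here; only the final carry is meaningful.
    return final_carry


print(final_product(jnp.array([1.0, 2.0, 3.0, 4.0])))  # 24.0
```

Skipping the stacked outputs avoids allocating an array the size of x when only the final value is needed.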
Problem
Implement jit_scan_sum(x) as a function compiled with @jax.jit that internally uses lax.scan to compute the sum of x.
- x: 1-D JAX array.
- Returns: scalar, the sum of all elements of x.