jit of scan

Why this matters

In production JAX code, training loops are often not written as Python for loops. Instead, jit is applied to a function that uses lax.scan internally, so the whole loop is compiled once. The combination jit(scan_fn) traces the entire sequential loop into a single XLA program, eliminating Python dispatch overhead on every step.

This is the pattern behind optax update loops, recurrent model training, and any scenario where you need both compilation and sequential state. The first call traces and compiles; subsequent calls with the same input shapes and dtypes reuse the cached executable and skip both steps. A call with a new shape or dtype triggers a fresh compilation.
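The caching behaviour can be observed with a small sketch. It assumes a hypothetical module-level counter, trace_count, that is incremented by a Python side effect inside the jitted function; because such side effects run only while JAX traces, the counter records how many times the function was traced rather than how many times it was called. The function name running_product is also illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0  # hypothetical counter: changes only while JAX traces the function

@jax.jit
def running_product(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls

    def step(carry, xi):
        new = carry * xi
        return new, new  # (next carry, per-step output)

    _, ys = lax.scan(step, jnp.ones((), x.dtype), x)
    return ys

running_product(jnp.ones(4))     # first call: trace and compile
running_product(jnp.ones(4))     # same shape and dtype: cached executable reused
count_same_shape = trace_count   # traced once so far

running_product(jnp.ones(8))     # new shape: traced and compiled again
count_new_shape = trace_count    # traced twice in total
```

The counter is only a diagnostic trick; real code should not rely on side effects inside jitted functions.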

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def scan_product(x):
    """Running product via scan, compiled with jit."""
    def step(carry, xi):
        return carry * xi, carry * xi
    _, ys = lax.scan(step, 1.0, x)
    return ys

scan_product(jnp.array([1.0, 2.0, 3.0, 4.0]))
# → [1., 2., 6., 24.]  (compiled on first call)

@jax.jit decorates the function definition; the function is compiled lazily on its first call. You can also apply jit inline: f = jax.jit(my_fn).
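As a minimal sketch of the inline form, the example below wraps an undecorated function with jax.jit after it has been defined. The names running_max and running_max_jit are hypothetical; keeping the original function around makes it easy to compare the eager and compiled results.

```python
import jax
import jax.numpy as jnp
from jax import lax

def running_max(x):
    """Running maximum via scan (plain, un-jitted function)."""
    def step(carry, xi):
        new = jnp.maximum(carry, xi)
        return new, new  # (next carry, per-step output)

    init = jnp.full((), -jnp.inf, dtype=x.dtype)  # identity element for max
    _, ys = lax.scan(step, init, x)
    return ys

# Inline jit: wrap the function object, then call the wrapper.
running_max_jit = jax.jit(running_max)

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
eager = running_max(x)         # runs without jit
compiled = running_max_jit(x)  # compiled on this first call
```

Both calls should produce the same values; the decorator form and the inline form differ only in where the wrapping happens.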

Common pitfalls

  • jit-ing the call, not the function. jax.jit(my_fn(x)) evaluates my_fn(x) eagerly and then tries to jit the result, not the function. Always write jax.jit(my_fn)(x) or decorate my_fn with @jax.jit.
  • scan body not pure. Any Python side effect inside step (print, list append) runs only at trace time, and scan traces step once rather than once per element. It will not execute on subsequent jit’d calls. Keep scan bodies pure.
  • Confusing carry and outputs. lax.scan returns (final_carry, stacked_ys). If you only need the final accumulated value, take final_carry (index 0) and emit None as the per-step output.
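The first and third pitfalls can be illustrated together in a short sketch. The function final_product is a hypothetical example: its step emits None as the per-step output, so only the final carry holds a result, and jit is applied to the function object before it is called.

```python
import jax
import jax.numpy as jnp
from jax import lax

def final_product(x):
    """Product of all elements, keeping only the final carry."""
    def step(carry, xi):
        return carry * xi, None  # None: no per-step output is stacked

    final_carry, ys = lax.scan(step, jnp.ones((), x.dtype), x)
    # ys is None here because step emitted None at every iteration
    return final_carry

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Correct: jit the function, then call the compiled wrapper.
result = jax.jit(final_product)(x)
```

Emitting None avoids stacking an intermediate array of per-step values that would be discarded anyway.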

Problem

Implement jit_scan_sum(x), a function decorated with @jax.jit that uses lax.scan internally to compute the sum of x.

  • x: 1-D JAX array.
  • Returns: scalar — sum of all elements of x.

Hints

jax jit scan composition
