Cumulative Sum via lax.scan
Why this matters
lax.scan is JAX’s fold-with-outputs primitive — the functional
equivalent of a for-loop that accumulates state and emits one output per
step. It compiles to a single XLA while loop, so it is JIT-friendly by
default: no Python-level unrolling, no trace-time blowup for long
sequences.
Cumulative sum is the simplest possible scan: carry is the running total, output is the same running total. Mastering this one-liner before moving on to tuple carries, RNN steps, or layer stacks pays off quickly.
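To make the fold-with-outputs description concrete, here is a plain-Python model of the semantics. It is only a sketch: scan_reference and add_step are illustrative names, not part of JAX, and the real lax.scan compiles the loop and stacks the outputs into an array instead of building a list.

```python
def scan_reference(f, init, xs):
    """Pure-Python model of what lax.scan computes (not how it is compiled)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # body returns (new_carry, per-step output)
        ys.append(y)
    return carry, ys  # lax.scan would stack ys into a single array


def add_step(carry, x):
    # Cumulative sum: the running total is both the carry and the output.
    total = carry + x
    return total, total


final, outputs = scan_reference(add_step, 0.0, [1.0, 2.0, 3.0])
print(final, outputs)  # 6.0 [1.0, 3.0, 6.0]
```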
Worked mini-example
import jax
from jax import lax

def step(carry, x_i):
    new_carry = carry + x_i
    return new_carry, new_carry  # (new_carry, output)

_, ys = lax.scan(step, 0.0, jax.numpy.array([1.0, 2.0, 3.0]))
# ys → [1.0, 3.0, 6.0]
lax.scan(body, init, xs) returns (final_carry, stacked_outputs).
The carry threads state forward; the second element of the body’s return
tuple is stacked into the output array.
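The difference between the carry and the stacked outputs is easier to see when the two are not the same value. The sketch below is an assumed variation, not part of the exercise: exclusive_step (an illustrative name) emits the total before adding the current element, so the final carry and the outputs diverge.

```python
import jax.numpy as jnp
from jax import lax


def exclusive_step(carry, x_i):
    # Emit the total *before* adding x_i, then thread the updated total forward.
    return carry + x_i, carry


xs = jnp.array([1.0, 2.0, 3.0])
final_carry, ys = lax.scan(exclusive_step, jnp.zeros((), dtype=xs.dtype), xs)

print(final_carry)  # 6.0
print(ys)           # [0. 1. 3.]
```

The final carry holds the full sum, while ys holds one entry per input element and has the same leading length as xs.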
Common pitfalls
- Returning a single value from body. The body MUST return a 2-tuple
  (new_carry, output). Returning just new_carry fails at trace time; recent
  JAX versions raise a TypeError saying the scan body output must be a pair.
- Wrong init dtype. If x is float32, init=0 (int) can cause a dtype
  mismatch. Use 0.0 or jnp.zeros((), dtype=x.dtype).
- Confusing the return order. lax.scan returns (final_carry, outputs); drop
  the carry with _ if you only need the per-step outputs.
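The three pitfalls can be avoided together in one small pattern. The sketch below uses a running product rather than a sum so that it does not give away the exercise; running_product is an illustrative name, and building init from xs.dtype is one reasonable way to keep the carry type stable, not the only one.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def running_product(xs):
    def step(carry, x_i):
        new_carry = carry * x_i
        return new_carry, new_carry  # always a 2-tuple: (new_carry, output)

    # Build init from xs.dtype so the carry has the same type going in and out.
    init = jnp.ones((), dtype=xs.dtype)
    _, ys = lax.scan(step, init, xs)  # return order: (final_carry, outputs)
    return ys


print(running_product(jnp.array([1.0, 2.0, 3.0])))  # [1. 2. 6.]
print(running_product(jnp.array([1, 2, 3])))        # [1 2 6]
```

Because init follows the input dtype, the same function handles both float and integer arrays without a separate code path.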
Problem
Implement cumulative_sum(x) using lax.scan. Do not use
jnp.cumsum — the point is to learn the scan API.
- x: 1-D JAX array.
- Returns: 1-D array, same shape. out[i] = sum(x[:i+1]).
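One way to test a candidate against this specification is to compare it, index by index, with the definition out[i] = sum(x[:i+1]). The helper below is a hypothetical sketch (check_cumulative_sum is not part of the exercise or of JAX), and the test input and tolerance are arbitrary choices.

```python
import jax.numpy as jnp


def check_cumulative_sum(fn):
    """Hypothetical helper: compare fn against out[i] = sum(x[:i+1])."""
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    out = fn(x)
    assert out.shape == x.shape, "output must have the same shape as x"
    for i in range(x.shape[0]):
        expected = float(jnp.sum(x[: i + 1]))
        assert abs(float(out[i]) - expected) < 1e-6, f"mismatch at index {i}"
    return True


# jnp.cumsum is only a known-good stand-in to exercise the checker;
# the exercise itself asks for a lax.scan implementation.
print(check_cumulative_sum(jnp.cumsum))  # True
```

Passing your own cumulative_sum in place of jnp.cumsum runs the same comparison against the scan-based version.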