easy primitives

Cumulative Sum via lax.scan

Why this matters

lax.scan is JAX’s fold-with-outputs primitive: the functional equivalent of a for-loop that accumulates state and emits one output per step. The body is traced once and the whole loop compiles to a single XLA while loop, so it is JIT-friendly by default: no Python-level unrolling, no trace-time blowup for long sequences.

Cumulative sum is the simplest possible scan: carry is the running total, output is the same running total. Mastering this one-liner before moving on to tuple carries, RNN steps, or layer stacks pays off quickly.
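
As a mental model, the fold-with-outputs behaviour can be sketched in plain Python. This is only an illustration of the semantics, not how JAX implements it; the names scan_reference and running_total_step are made up for this sketch, and the outputs are collected in a Python list rather than stacked into an array.

```python
def scan_reference(body, init, xs):
    """Plain-Python model of a scan: thread a carry, collect one output per step."""
    carry = init
    ys = []
    for x_i in xs:
        # body returns (new_carry, output) at every step
        carry, y_i = body(carry, x_i)
        ys.append(y_i)
    return carry, ys


def running_total_step(carry, x_i):
    # For a cumulative sum, the carry and the per-step output are the same value.
    new_carry = carry + x_i
    return new_carry, new_carry


final_carry, outputs = scan_reference(running_total_step, 0.0, [1.0, 2.0, 3.0])
# final_carry is 6.0, outputs is [1.0, 3.0, 6.0]
```

The real primitive differs mainly in that the body is traced once instead of being executed step by step in Python, and the outputs come back as a single stacked array.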

Worked mini-example

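import jax.numpy as jnp
from jax import lax

def step(carry, x_i):
    # carry is the running total so far; x_i is the current element of xs
    new_carry = carry + x_i
    return new_carry, new_carry   # (new_carry, output)

# The final carry is discarded; ys holds one output per step.
_, ys = lax.scan(step, 0.0, jnp.array([1.0, 2.0, 3.0]))
# ys → [1.0, 3.0, 6.0]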

lax.scan(body, init, xs) returns (final_carry, stacked_outputs); in the example above, body is step. The carry threads state forward from one step to the next; the second element of the body’s return tuple is stacked along a new leading axis to form the output array.
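
One way to check that the loop is not unrolled is to inspect the jaxpr with jax.make_jaxpr. The sketch below assumes a reasonably recent JAX release; running_total and primitive_names are illustrative helper names, not part of the library. It shows the traced program only, where the loop appears as a single scan equation; the lowering to an XLA while loop happens later and is not shown here.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x_i):
    new_carry = carry + x_i
    return new_carry, new_carry


def running_total(xs):
    # Scalar zero with the same dtype as the input, used as the initial carry.
    init = jnp.zeros((), dtype=xs.dtype)
    _, ys = lax.scan(step, init, xs)
    return ys


def primitive_names(n):
    """Names of the top-level primitives traced for an input of length n."""
    closed = jax.make_jaxpr(running_total)(jnp.ones(n, dtype=jnp.float32))
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


short_names = primitive_names(8)
long_names = primitive_names(4096)

# The traced program contains one scan equation, and its top-level structure
# is the same whether the sequence has 8 elements or 4096.
print(short_names)
print(short_names == long_names)
```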

Common pitfalls

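  • Returning a single value from body. The body must return a 2-tuple (new_carry, output). Returning just new_carry fails at trace time with a TypeError (the scan body output must be a pair).
  • Wrong init dtype. The carry must keep the same shape and dtype on every step, so init has to match what the body returns. If x is float32, init=0 (int) can cause a dtype mismatch. Use 0.0 or, to match any input dtype, jnp.zeros((), dtype=x.dtype).
  • Confusing the return order. lax.scan returns (final_carry, outputs), in that order. Drop the carry with _ if you only need the per-step outputs.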
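
The dtype and return-order pitfalls above can be avoided together, as in the following minimal sketch. The helper name scan_with_matching_init is hypothetical; the idea is simply to build the initial carry from the input's dtype and to unpack both return values explicitly.

```python
import jax.numpy as jnp
from jax import lax


def scan_with_matching_init(x):
    # Build the initial carry from x.dtype so the carry type never changes
    # between steps, whatever dtype the caller passes in.
    init = jnp.zeros((), dtype=x.dtype)

    def step(carry, x_i):
        new_carry = carry + x_i
        return new_carry, new_carry  # always a 2-tuple: (new_carry, output)

    # Return order is (final_carry, stacked_outputs).
    final_carry, ys = lax.scan(step, init, x)
    return final_carry, ys


x_f32 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
x_i32 = jnp.array([1, 2, 3], dtype=jnp.int32)

total_f32, ys_f32 = scan_with_matching_init(x_f32)
total_i32, ys_i32 = scan_with_matching_init(x_i32)
# ys_f32 stays float32 and ys_i32 stays int32; each final carry equals the last output.
```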

Problem

Implement cumulative_sum(x) using lax.scan. Do not use jnp.cumsum — the point is to learn the scan API.

  • x: 1-D jax array.
  • Returns: 1-D array, same shape. out[i] = sum(x[:i+1]).
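
To check an attempt against the specification out[i] = sum(x[:i+1]), a slow but direct reference can help. The sketch below uses NumPy and deliberately avoids both scan and cumsum; reference_prefix_sums is a hypothetical helper name, and the quadratic loop is only meant for small test inputs.

```python
import numpy as np


def reference_prefix_sums(x):
    """Literal translation of out[i] = sum(x[:i+1]); O(n^2), for testing only."""
    x = np.asarray(x)
    sums = [x[: i + 1].sum() for i in range(x.shape[0])]
    # Cast back to the input dtype so the result has the same dtype and shape as x.
    return np.array(sums, dtype=x.dtype)


expected = reference_prefix_sums(np.array([1.0, 2.0, 3.0], dtype=np.float32))
# expected is [1.0, 3.0, 6.0]

# Assumed usage, with cumulative_sum being your own solution:
# np.testing.assert_allclose(np.asarray(cumulative_sum(x)), reference_prefix_sums(x))
```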

Hints

jax scan fold
