Associative Scan (Parallel Cumsum)
Why this matters
lax.associative_scan(op, x) is JAX's parallel scan primitive. Unlike
lax.scan, which runs sequentially (n dependent steps, one after another),
associative_scan exploits the associativity of op to reach O(log n)
depth with O(n) total work on parallel hardware such as GPUs and TPUs.
This is the classic work-efficient parallel prefix-sum algorithm.
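To make the sequential-vs-parallel contrast concrete, here is a sketch that computes the same cumulative sum three ways: with lax.scan (n dependent steps), with a hand-written doubling scan, and with lax.associative_scan. The doubling scan is the simpler Hillis-Steele scheme: it has O(log n) depth but O(n log n) work, so it is not the work-efficient algorithm described above. The function hillis_steele_cumsum is hypothetical and only illustrates why associativity enables logarithmic depth.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# 1) Sequential: lax.scan threads a running total through n dependent steps.
def step(total, xi):
    total = total + xi
    return total, total  # (new carry, output for this position)

_, seq_out = lax.scan(step, jnp.zeros((), x.dtype), x)

# 2) Hypothetical hand-written doubling scan (Hillis-Steele): ceil(log2 n) rounds,
#    each one a single vectorized add. O(log n) depth, but O(n log n) work.
def hillis_steele_cumsum(x):
    n = x.shape[0]
    offset = 1
    while offset < n:
        # Shift right by `offset`, filling with 0 (the identity element of +).
        shifted = jnp.concatenate([jnp.zeros((offset,), x.dtype), x[:-offset]])
        x = x + shifted
        offset *= 2
    return x

hs_out = hillis_steele_cumsum(x)

# 3) JAX's primitive: it only needs the associative binary operator.
par_out = lax.associative_scan(jnp.add, x)

# All three hold 1, 3, 6, 10, 15.
```

The zero padding in the doubling scan works only because 0 is the identity of addition; a general operator would need its own identity element or masking, whereas lax.associative_scan needs nothing beyond op itself.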
This matters in practice for long-sequence computations that can be expressed with an associative combining operator: cumulative sum, cumulative product, running max, and, most importantly in deep learning, parallel RNN/SSM state updates (e.g., Mamba, S4, linear attention).
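As a sketch of the RNN/SSM use case, assume a scalar linear recurrence h[t] = a[t] * h[t-1] + b[t] with initial state h[-1] = 0. This is a simplified stand-in, not the actual update rule of Mamba, S4, or any specific model. Each step is an affine map h -> a*h + b, and composing affine maps is associative, so the pair (a, b) can be scanned with a custom combine function; associative_scan accepts a tuple (pytree) of arrays. The names combine, linear_recurrence, and linear_recurrence_sequential are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def combine(left, right):
    # Each element (a, b) encodes the affine map h -> a * h + b.
    # Applying `left` first, then `right`:
    #   h -> a_r * (a_l * h + b_l) + b_r = (a_r * a_l) * h + (a_r * b_l + b_r)
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def linear_recurrence(a, b):
    """h[t] = a[t] * h[t-1] + b[t], assuming h[-1] = 0."""
    # The scanned b-component is the composed offset, i.e. h[t] when h[-1] = 0.
    _, h = lax.associative_scan(combine, (a, b))
    return h

def linear_recurrence_sequential(a, b):
    # Reference: the same recurrence evaluated step by step with lax.scan.
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h
    _, h = lax.scan(step, jnp.zeros((), b.dtype), (a, b))
    return h

a = jnp.array([0.5, 0.9, 0.1, 1.0])
b = jnp.array([1.0, 2.0, 3.0, 4.0])
h_par = linear_recurrence(a, b)
h_seq = linear_recurrence_sequential(a, b)
# Mathematically both are 1.0, 2.9, 3.29, 7.29 (up to float32 rounding).
```

The argument order matters here because affine composition is associative but not commutative: associative_scan passes the earlier elements as the first argument to the combine function.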
Worked mini-example
from jax import lax
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0, 4.0])
out = lax.associative_scan(lambda a, b: a + b, x)
# → [1.0, 3.0, 6.0, 10.0]
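Each output element is out[i] = x[0] ⊕ x[1] ⊕ ... ⊕ x[i], where ⊕ denotes op.
Associativity guarantees that the result matches a sequential scan, however
the parallel algorithm groups the terms (up to floating-point rounding).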
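The same call works for any associative operator. This sketch runs a running max and a cumulative product and checks both against a plain left fold of each prefix, matching the definition out[i] = x[0] ⊕ ... ⊕ x[i]. The helper prefix_fold is hypothetical and exists only for the comparison.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])

running_max = lax.associative_scan(jnp.maximum, x)    # 3, 3, 4, 4, 5
running_prod = lax.associative_scan(jnp.multiply, x)  # 3, 3, 12, 12, 60

def prefix_fold(op, x):
    # Sequential left fold: out[0] = x[0], out[i] = op(out[i-1], x[i]).
    out = [x[0]]
    for xi in x[1:]:
        out.append(op(out[-1], xi))
    return jnp.stack(out)
```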
Common pitfalls
- Non-associative ops. The op MUST satisfy (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c), where ⊕ denotes op. Subtraction, division, and many other ops (e.g., averaging two values) are not associative and will give wrong results: there is no runtime error, just silently incorrect output.
- Floating-point associativity. Floating-point addition is associative only up to rounding, so the parallel result can differ from a sequential sum in the last few bits. This is generally fine in practice.
- Axis argument. lax.associative_scan accepts an optional axis argument (default 0) to scan along a different axis of a 2-D or higher-dimensional input.
- Shape. The output shape is identical to the input shape.
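A quick sketch of two of these pitfalls. Scanning with subtraction runs without any error but silently disagrees with a sequential left fold, because the parallel algorithm regroups the terms. The second part shows the axis argument on a 2-D input and that the output shape equals the input shape. The helper sequential_fold is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def sequential_fold(op, x):
    # Left fold: out[0] = x[0], out[i] = op(out[i-1], x[i]).
    out = [x[0]]
    for xi in x[1:]:
        out.append(op(out[-1], xi))
    return jnp.stack(out)

# Pitfall: subtraction is not associative.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
expected = sequential_fold(jnp.subtract, x)       # 1, -1, -4, -8
parallel = lax.associative_scan(jnp.subtract, x)  # no error, but silently differs

# Axis argument: scan down the columns (default axis=0) or along each row (axis=1).
m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
down_columns = lax.associative_scan(jnp.add, m)         # [[1, 2, 3], [5, 7, 9]]
along_rows = lax.associative_scan(jnp.add, m, axis=1)   # [[1, 3, 6], [4, 9, 15]]
```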
Problem
Implement parallel_cumsum(x) using lax.associative_scan. Do not
use jnp.cumsum or lax.scan: the point is to learn the parallel API.
- x: 1-D JAX array.
- Returns: 1-D array of the same shape, with out[i] = sum(x[:i+1]).
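A minimal sketch of one possible solution, followed by a check against the spec out[i] = sum(x[:i+1]). Here jnp.cumsum appears only as a reference in the check, not in the implementation; the random test input from jax.random is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def parallel_cumsum(x):
    # Inclusive prefix sum: out[i] = x[0] + ... + x[i].
    # jnp.add is associative, so associative_scan can evaluate it in O(log n) depth.
    return lax.associative_scan(jnp.add, x)

# Check the spec on a random input; jnp.cumsum is used only as a reference here.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000,))
out = jax.jit(parallel_cumsum)(x)
assert out.shape == x.shape
# A loose tolerance allows for float32 rounding differences between summation orders.
assert jnp.allclose(out, jnp.cumsum(x), atol=1e-3)
```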