hard primitives

Associative Scan (Parallel Cumsum)

Why this matters

lax.associative_scan(op, x) is JAX's parallel scan primitive. Unlike lax.scan, which runs sequentially (O(n) dependent steps, with no parallelism along the scanned axis), associative_scan exploits the associativity of op to reach O(log n) depth on hardware that can run the independent combinations in parallel. This is the classic work-efficient parallel prefix sum algorithm.

This matters in practice for long-sequence operations that can be expressed as associative reductions: cumulative sum, cumulative product, running max, and, most importantly in deep learning, parallel RNN/SSM state updates (e.g., Mamba, S4, linear attention).
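To make the RNN/SSM case concrete, here is a minimal sketch of a first-order linear recurrence h[t] = a[t] * h[t-1] + b[t] (taking h[-1] = 0) computed with lax.associative_scan. The names combine and linear_recurrence are illustrative and not part of JAX, and real models such as Mamba or S4 use more elaborate parameterizations. The sketch only shows why such updates fit the pattern: each step is an affine map, and composing affine maps is associative.

```python
import jax.numpy as jnp
from jax import lax


def combine(earlier, later):
    # Each element (a, b) stands for the affine map h -> a * h + b.
    # Applying `earlier` first and `later` second yields another affine map:
    # a2 * (a1 * h + b1) + b2 = (a1 * a2) * h + (a2 * b1 + b2).
    a1, b1 = earlier
    a2, b2 = later
    return a1 * a2, a2 * b1 + b2


def linear_recurrence(a, b):
    # Hypothetical helper: returns h with h[t] = a[t] * h[t-1] + b[t], h[-1] = 0.
    # The scanned input is a tuple (pytree) of two arrays of equal length.
    _, h = lax.associative_scan(combine, (a, b))
    return h


a = jnp.array([0.5, 0.5, 0.5, 0.5])
b = jnp.array([1.0, 1.0, 1.0, 1.0])
h = linear_recurrence(a, b)
print(h)  # values: 1.0, 1.5, 1.75, 1.875
```

The second component of the scanned pair is the hidden state itself, because the accumulated affine map is being applied to an initial state of zero.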

Worked mini-example

from jax import lax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
out = lax.associative_scan(lambda a, b: a + b, x)
# → [1.0, 3.0, 6.0, 10.0]

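Each output element is out[i] = x[0] ⊕ x[1] ⊕ ... ⊕ x[i], where ⊕ denotes op. Because op is associative, the grouping chosen by the parallel algorithm does not change the value, so the result matches a sequential scan.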
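The same prefix semantics hold for other associative operators. The sketch below runs a running max and a cumulative product and compares each against a plain Python left-to-right loop. The helper sequential_prefix is hypothetical and exists only as a reference for the comparison.

```python
import jax.numpy as jnp
from jax import lax


def sequential_prefix(op, xs):
    # Reference implementation: fold op from left to right, one element at a time.
    out = [xs[0]]
    for v in xs[1:]:
        out.append(op(out[-1], v))
    return jnp.stack(out)


x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])

running_max = lax.associative_scan(jnp.maximum, x)
running_prod = lax.associative_scan(jnp.multiply, x)

print(running_max)   # values: 3, 3, 4, 4, 5
print(running_prod)  # values: 3, 3, 12, 12, 60

# Both agree with the sequential reference because max and multiply are associative.
print(jnp.allclose(running_max, sequential_prefix(jnp.maximum, x)))
print(jnp.allclose(running_prod, sequential_prefix(jnp.multiply, x)))
```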

Common pitfalls

  • Non-associative ops. The op MUST satisfy (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c). Subtraction and division are not associative, and neither are many other binary functions. Passing one gives wrong results with no runtime error, just silently incorrect output.
  • Floating-point associativity. Floating-point addition is only approximately associative, so the result can differ from a sequential cumulative sum by small rounding errors. This is generally fine in practice, but compare results with a tolerance rather than exact equality.
  • Axis argument. lax.associative_scan accepts an optional axis argument (default 0) to scan along a different axis of a 2-D+ input.
  • Shape. Output shape is identical to input shape.
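Two of the pitfalls above are easy to check by hand. The following sketch first scans a small 2-D array along each axis, then contrasts a subtraction scan with a sequential left-to-right fold. Variable names are illustrative, and the exact values that a non-associative op yields depend on how the implementation groups elements, so none are claimed here for that case.

```python
import jax.numpy as jnp
from jax import lax

# Axis argument: scan a 2-D array along either axis.
m = jnp.arange(6.0).reshape(2, 3)                      # [[0, 1, 2], [3, 4, 5]]
down_columns = lax.associative_scan(jnp.add, m)        # default axis=0
along_rows = lax.associative_scan(jnp.add, m, axis=1)
print(down_columns)  # [[0, 1, 2], [3, 5, 7]]
print(along_rows)    # [[0, 1, 3], [3, 7, 12]]

# Non-associative op: no error is raised, but the result is not a valid prefix fold.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
tree_result = lax.associative_scan(jnp.subtract, x)

# Sequential left fold for comparison: ((1 - 2) - 3) - 4, and its prefixes.
seq = [x[0]]
for v in x[1:]:
    seq.append(seq[-1] - v)
sequential_result = jnp.stack(seq)

print(sequential_result)  # values: 1, -1, -4, -8
print(tree_result)        # same shape, but not guaranteed to match the line above
```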

Problem

Implement parallel_cumsum(x) using lax.associative_scan. Do not use jnp.cumsum or lax.scan; the point is to learn the parallel API.

  • x: 1-D JAX array.
  • Returns: 1-D array of the same shape, with out[i] = sum(x[:i+1]).
