medium primitives

lax.scan with reverse=True

Why this matters

lax.scan(body, init, xs, reverse=True) runs the scan from end to start: the body is applied to xs[-1] first, then xs[-2], and so on. The stacked outputs are still indexed by input position, so out[i] is the value produced while processing xs[i], not the i-th value computed in scan order.
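A quick way to convince yourself of the indexing rule is to compare a reverse scan with "flip the input, run a forward scan, flip the outputs back". The sketch below assumes a simple running-sum body (running_sum is an illustrative name, not part of JAX) and checks that both routes give the same array.

```python
import jax.numpy as jnp
from jax import lax

def running_sum(carry, x_i):
    # Add the current element to the carry and emit the new total.
    new = carry + x_i
    return new, new

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# Reverse scan: visits xs[3], xs[2], ... but stores each output at its input index.
_, ys_rev = lax.scan(running_sum, 0.0, xs, reverse=True)

# Equivalent route: flip the input, scan forward, flip the outputs back.
_, ys_fwd = lax.scan(running_sum, 0.0, jnp.flip(xs))
ys_flipped = jnp.flip(ys_fwd)

print(ys_rev.tolist())      # [10.0, 9.0, 7.0, 4.0]
print(ys_flipped.tolist())  # [10.0, 9.0, 7.0, 4.0]
```

reverse=True avoids the two explicit flips and keeps the intent visible at the call site.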

This is essential for reverse-time accumulations: discounted future returns in reinforcement learning, backward message passing, and any algorithm that needs suffix sums rather than prefix sums.
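As an illustration of the reinforcement-learning case, discounted returns follow the recurrence G[t] = r[t] + gamma * G[t + 1], which is naturally a reverse scan with the carry holding G[t + 1]. This is a minimal sketch assuming a single episode with no termination flags inside the sequence; discounted_returns is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax import lax

def discounted_returns(rewards, gamma):
    """Return G[t] = rewards[t] + gamma * G[t + 1], with G = 0 past the last step."""
    def step(future_return, r_t):
        # The carry holds G[t + 1]; combine it with the reward at step t.
        g_t = r_t + gamma * future_return
        return g_t, g_t

    # Start from a zero of the same dtype as the rewards so the carry type is stable.
    init = jnp.zeros((), dtype=rewards.dtype)
    _, returns = lax.scan(step, init, rewards, reverse=True)
    return returns

rewards = jnp.array([1.0, 0.0, 2.0])
print(discounted_returns(rewards, 0.5).tolist())  # [1.5, 1.0, 2.0]
```

Handling episode boundaries would need an extra mask input (for example, zeroing the carry where an episode ends), which this sketch deliberately leaves out.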

Worked mini-example

import jax.numpy as jnp
from jax import lax

def step(carry, x_i):
    # Add the current element to the running total; emit the total as this step's output.
    new = carry + x_i
    return new, new

x = jnp.array([1.0, 2.0, 3.0])
_, ys = lax.scan(step, 0.0, x, reverse=True)
# scan order: 3 → 2 → 1, so ys = [6, 5, 3]
# out[0] = 1+2+3 = 6, out[1] = 2+3 = 5, out[2] = 3

The body signature is identical to that of a forward scan: (carry, x_i) → (new_carry, output). Just add reverse=True to the lax.scan call.
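Because new_carry and output are separate return values, they do not have to be the same thing. As a sketch (exclusive_step is an illustrative name), emitting the carry before adding x_i turns the worked example into an exclusive suffix sum, out[i] = sum(x[i+1:]), while the final carry still holds the grand total.

```python
import jax.numpy as jnp
from jax import lax

def exclusive_step(carry, x_i):
    # carry = sum of the elements after position i (already processed in reverse order).
    # Emit that value first, then pass the updated total on to position i - 1.
    return carry + x_i, carry

x = jnp.array([1.0, 2.0, 3.0])
total, ys = lax.scan(exclusive_step, 0.0, x, reverse=True)
print(ys.tolist())   # [5.0, 3.0, 0.0]
print(float(total))  # 6.0
```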

Common pitfalls

  • Forgetting reverse=True. Without it you get a forward cumsum, not a reverse one.
  • Output ordering confusion. out[i] always corresponds to xs[i] in the input: the reversal changes the processing order, not the index order of the stacked outputs.
  • Init value. The carry starts at init and absorbs elements beginning with the last one; init plays the same role as in a forward scan, just at the other end of the array.
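  • Carry dtype. lax.scan expects the carry to keep the same shape and dtype on every iteration, so use 0.0 (float) rather than 0 (int) when accumulating floats; a mismatched init can lead to a type error or an unexpected promotion.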
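The pitfalls above can be checked side by side. This sketch builds the init with jnp.zeros using the input's dtype, runs the same body with and without reverse=True, and compares the forward result against jnp.cumsum; the float32 input is just an example choice.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x_i):
    new = carry + x_i
    return new, new

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

# Init with the same dtype as the data so the carry type never changes.
init = jnp.zeros((), dtype=x.dtype)

_, forward = lax.scan(step, init, x)                 # prefix sums (reverse=True forgotten)
_, backward = lax.scan(step, init, x, reverse=True)  # suffix sums

print(forward.tolist())   # [1.0, 3.0, 6.0]
print(backward.tolist())  # [6.0, 5.0, 3.0]
# In both cases output index i lines up with input index i:
# backward[0] is the sum starting at x[0], even though it was computed last.
```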

Problem

Implement reverse_cumsum(x) using lax.scan with reverse=True. The output should satisfy out[i] = sum(x[i:]), i.e. the suffix sum.

  • x: 1-D jax array.
  • Returns: a 1-D array of the same shape, with out[i] = sum(x[i:]).
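One possible sketch, reusing the worked example's body with a dtype-matched init. reverse_cumsum_reference is a hypothetical helper used only to cross-check the result with jnp.flip and jnp.cumsum; it is not required by the problem.

```python
import jax.numpy as jnp
from jax import lax

def reverse_cumsum(x):
    """Suffix sum: out[i] = sum(x[i:]) for a 1-D array x."""
    def step(carry, x_i):
        # Running total of the elements seen so far (from the end of x).
        new = carry + x_i
        return new, new

    init = jnp.zeros((), dtype=x.dtype)  # match the input dtype
    _, out = lax.scan(step, init, x, reverse=True)
    return out

def reverse_cumsum_reference(x):
    # Cross-check: flip, forward cumulative sum, flip back.
    return jnp.flip(jnp.cumsum(jnp.flip(x)))

x = jnp.array([4.0, -1.0, 2.0, 0.5])
print(reverse_cumsum(x).tolist())  # [5.5, 1.5, 2.5, 0.5]
print(bool(jnp.allclose(reverse_cumsum(x), reverse_cumsum_reference(x))))  # True
```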

Hints

jax scan reverse
