lax.scan with reverse=True
Why this matters
lax.scan(body, init, xs, reverse=True) runs the scan from end to start: the body is applied to xs[-1] first, then xs[-2], and so on. The output array is re-indexed so that out[i] corresponds to position i of the input, not to the scan order.
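To make the re-indexing concrete, here is a minimal sketch with a deliberately order-sensitive body. It appends each element as a new decimal digit, a toy body chosen only for illustration; build_digits is a hypothetical name. The sketch compares a reverse scan with the equivalence stated in the lax.scan docstring: reverse=True behaves like reversing the leading axis of both xs and the stacked outputs.

```python
import jax.numpy as jnp
from jax import lax


def build_digits(carry, x_i):
    # Order-sensitive toy body: appends x_i as a new decimal digit of the carry.
    new = carry * 10 + x_i
    return new, new


xs = jnp.array([1, 2, 3], dtype=jnp.int32)
init = jnp.int32(0)

# Reverse scan: the body sees xs[2], then xs[1], then xs[0].
final_rev, ys_rev = lax.scan(build_digits, init, xs, reverse=True)

# Equivalent formulation: flip xs, scan forward, flip the stacked outputs back.
final_fwd, ys_fwd = lax.scan(build_digits, init, jnp.flip(xs))
ys_manual = jnp.flip(ys_fwd)

# ys_rev and ys_manual both hold [321, 32, 3]; ys_rev[i] is the output
# produced when the body processed xs[i].
```

Because the body is not commutative, the digits show the visiting order directly: 3 was consumed first and 1 last. The outputs are still stored at the positions of the elements that produced them.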
This is essential for reverse-time accumulations: discounted future returns in reinforcement learning, backward message passing, and any algorithm that needs suffix sums rather than prefix sums.
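As a sketch of the first use case, discounted returns follow the backward recurrence G[t] = rewards[t] + gamma * G[t+1]. The function name discounted_returns and the choice of G = 0 after the final step (no bootstrap value) are assumptions made for illustration. The sketch also ignores episode boundaries (no done-flag masking), which a real RL pipeline would usually need.

```python
import jax.numpy as jnp
from jax import lax


def discounted_returns(rewards, gamma):
    """G[t] = rewards[t] + gamma * G[t+1], taking G = 0 after the last step."""

    def step(future_return, r_t):
        # future_return is G[t+1]; fold in the reward received at step t.
        g_t = r_t + gamma * future_return
        return g_t, g_t  # carry G[t] backward and also emit it as out[t]

    init = jnp.zeros((), dtype=rewards.dtype)  # assumed: no bootstrap value
    _, returns = lax.scan(step, init, rewards, reverse=True)
    return returns


rewards = jnp.array([1.0, 0.0, 2.0])
returns = discounted_returns(rewards, 0.5)
# G[2] = 2.0; G[1] = 0.0 + 0.5 * 2.0 = 1.0; G[0] = 1.0 + 0.5 * 1.0 = 1.5
```

The carry plays the role of "the return from the next step onward". A reverse scan is therefore a natural fit: each G[t] depends only on later time steps.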
Worked mini-example
```python
import jax.numpy as jnp
from jax import lax

def step(carry, x_i):
    # Add the current element to the running total and emit the new total.
    new = carry + x_i
    return new, new

x = jnp.array([1.0, 2.0, 3.0])
_, ys = lax.scan(step, 0.0, x, reverse=True)
# scan order: 3 → 2 → 1, so ys = [6., 5., 3.]
# out[0] = 1+2+3 = 6, out[1] = 2+3 = 5, out[2] = 3
```
The body signature is identical to the forward scan's: (carry, x_i) → (new_carry, output).
Just add reverse=True to the lax.scan call.
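A quick check that the flag alone changes the direction: the same body and init, scanned once without and once with reverse=True, give prefix sums and suffix sums respectively. This is a minimal sketch assuming float inputs.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x_i):
    new = carry + x_i
    return new, new


x = jnp.array([1.0, 2.0, 3.0])

# Same body, same init; only the reverse flag differs.
_, prefix = lax.scan(step, 0.0, x)                # forward: prefix sums [1., 3., 6.]
_, suffix = lax.scan(step, 0.0, x, reverse=True)  # reverse: suffix sums [6., 5., 3.]
```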
Common pitfalls
- Forgetting reverse=True. Without it you get a forward cumsum (prefix sums), not a reverse one.
- Output ordering confusion. out[i] always corresponds to xs[i] in the input. The reversal is an implementation detail that is not reflected in the output's index order.
- Init value. The carry starts at init and accumulates from the last element. The semantics are the same as a forward scan, but starting from the other end.
- Carry dtype. Use 0.0 (float), not 0 (int), when accumulating floats. lax.scan requires the carry to keep a consistent type across iterations, and matching init to the data avoids dtype promotion surprises.
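The output-ordering pitfall is easiest to see by recording, at each step, how many steps have already run. The body name record_visit is hypothetical. The sketch relies only on lax.scan accepting a tuple as the per-step output, which it stacks element-wise.

```python
import jax.numpy as jnp
from jax import lax


def record_visit(steps_done, x_i):
    # Emit (steps that ran before this one, the element received); count up.
    return steps_done + 1, (steps_done, x_i)


xs = jnp.array([10, 20, 30], dtype=jnp.int32)
_, (visit_step, seen) = lax.scan(record_visit, jnp.int32(0), xs, reverse=True)

# visit_step == [2, 1, 0]: xs[2] was processed first and xs[0] last.
# seen == [10, 20, 30]: out[i] still lines up with xs[i], despite the reversed order.
```

The carry here is an explicit jnp.int32 scalar so that it has the same dtype as the incremented value on every step, in the spirit of the carry-dtype pitfall.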
Problem
Implement reverse_cumsum(x) using lax.scan with reverse=True.
The output should satisfy out[i] = sum(x[i:]), i.e. the suffix sum.
- x: 1-D JAX array.
- Returns: 1-D array of the same shape, with out[i] = sum(x[i:]).
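One possible sketch, assuming x is a 1-D array of a numeric dtype. It also includes a scan-free reference (flip, jnp.cumsum, flip back) that a solution can be checked against. Initializing the carry with jnp.zeros((), dtype=x.dtype) is a choice made here so the carry dtype always matches the input, for both float and integer arrays.

```python
import jax.numpy as jnp
from jax import lax


def reverse_cumsum(x):
    """Suffix sums: out[i] = sum(x[i:]) for a 1-D array x."""

    def step(running_total, x_i):
        running_total = running_total + x_i
        return running_total, running_total

    # Zero carry in x's own dtype, so the carry type never changes between steps.
    init = jnp.zeros((), dtype=x.dtype)
    _, out = lax.scan(step, init, x, reverse=True)
    return out


x = jnp.array([1.0, 2.0, 3.0, 4.0])
out = reverse_cumsum(x)
# Scan-free reference for cross-checking: flip, cumulative sum, flip back.
reference = jnp.flip(jnp.cumsum(jnp.flip(x)))
# out and reference both hold [10., 9., 7., 4.]
```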