vmap over lax.scan
Why this matters
lax.scan is JAX's functional loop for sequential state. It is the idiomatic way to implement RNNs, cumulative operations, and any computation where each step depends on the previous one. In practice, though, you usually need to run that loop over a batch of inputs at once.
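The solution is vmap(per_row): wrap the scan in a helper that operates on a single row, then vmap that helper across the batch dimension. Semantically, each row gets its own independent scan.

Under the hood, JAX batches the scan itself. It produces one scan whose carry and per-step computation gain an extra batch dimension. The loop over time stays sequential, but every step processes all rows in a single vectorised operation that XLA can run efficiently on hardware. This composition (vmap of scan) is a common building block for batched RNN training in JAX.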
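The sketch below illustrates the "batched scan" view. It compares vmap of a per-row scan with a hand-batched scan that carries a vector of shape (N,) and walks over the transposed input. The exponential-moving-average step, the ALPHA value, and the names ema_step and ema_row are illustrative assumptions, not part of the original example. The only point is that both forms compute the same result.

```python
import jax
import jax.numpy as jnp
from jax import lax

ALPHA = 0.5  # hypothetical smoothing factor, chosen only for this illustration

def ema_step(carry, x):
    # Exponential moving average: blend the new input into the running state.
    new = ALPHA * x + (1.0 - ALPHA) * carry
    return new, new

def ema_row(row):
    # One scan over a single 1-D row, starting from a zero state.
    _, ys = lax.scan(ema_step, jnp.zeros((), row.dtype), row)
    return ys

x = jnp.arange(12.0).reshape(3, 4)  # N=3 rows, T=4 steps

# Version 1: vmap a per-row scan across the batch dimension.
via_vmap = jax.vmap(ema_row)(x)

# Version 2: a single scan over time whose carry has shape (N,).
# lax.scan slices its xs along axis 0, so transpose to put time first.
_, ys_time_major = lax.scan(ema_step, jnp.zeros(x.shape[0], x.dtype), x.T)
via_manual = ys_time_major.T

print(jnp.allclose(via_vmap, via_manual))  # True
```

Version 2 is roughly the shape of program JAX builds when it batches a scan. Writing version 1 is usually clearer, because the per-row logic never has to mention the batch axis.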
Worked mini-example
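```python
import jax
import jax.numpy as jnp
from jax import lax

def running_max_batch(x_batch):
    """Cumulative max for each row."""
    def per_row(row):
        # Scan body: carry the max so far and emit it at every step.
        def step(carry, x):
            new = jnp.maximum(carry, x)
            return new, new
        # Start from -inf so the first element always becomes the max.
        _, ys = lax.scan(step, -jnp.inf, row)
        return ys
    # Map the single-row function over axis 0 of the batch.
    return jax.vmap(per_row)(x_batch)

running_max_batch(jnp.array([[3., 1., 4.], [1., 5., 2.]]))
# → [[3., 3., 4.], [1., 5., 5.]]
```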
lax.scan(step, init, xs) returns a tuple (final_carry, stacked_outputs). The second element stacks the y value emitted at each step along a new leading axis. For a row of length T, it therefore has length T.
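Here is a small sketch of both return values. It reuses the running-max step from the example above, redefined so the snippet runs on its own. per_row_both is a hypothetical helper that returns the whole (final_carry, ys) tuple. It shows that vmap batches both outputs along a new leading axis.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new = jnp.maximum(carry, x)
    return new, new  # (next carry, per-step output y)

# Single row: the carry is a scalar; ys stacks one y per step.
row = jnp.array([3., 1., 4.])
final_carry, ys = lax.scan(step, -jnp.inf, row)
print(final_carry)  # 4.0
print(ys)           # [3. 3. 4.]

def per_row_both(row):
    # Hypothetical helper: keep both outputs instead of discarding the carry.
    return lax.scan(step, -jnp.inf, row)

batch = jnp.array([[3., 1., 4.], [1., 5., 2.]])
carries, ys_batch = jax.vmap(per_row_both)(batch)
print(carries.shape, ys_batch.shape)  # (2,) (2, 3)
```

For this step function, the final carry always equals the last stacked output. That is why the worked example can safely discard it with _.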
Common pitfalls
- Defining the scan body outside per_row. If the step function is defined at module level and captures nothing, this works. Still, a common idiom is to nest it inside per_row, both for clarity and to allow closure over per-row state.
- Vmapping lax.scan directly. jax.vmap(lax.scan)(step, init, xs) does not work: vmap maps over array arguments, and step is a Python function, not an array. Always wrap the scan in a helper that takes only the arrays to batch, then vmap that helper.
- Forgetting to unpack outputs. lax.scan returns (final_carry, ys). Discard the carry with _ when you only need the per-step outputs.
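The following sketch shows two variations on the helper pattern, assuming a float batch of shape (N, T).

- The first keeps the scan body at module level. It uses functools.partial as a minimal wrapper that fixes the step function, so vmap only sees array arguments: in_axes=(None, 0) shares init across rows and maps xs over its first axis.
- The second closes over per-row state. The hypothetical scaled_running_max multiplies each row by its own scale factor, which vmap passes in alongside the row.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # Module-level scan body: captures nothing, so it can be reused anywhere.
    new = jnp.maximum(carry, x)
    return new, new

batch = jnp.array([[3., 1., 4.], [1., 5., 2.]])

# Variation 1: partial fixes the (non-array) step function; vmap then maps
# only over array arguments: init is shared (None), xs is batched (0).
scan_rows = jax.vmap(partial(lax.scan, step), in_axes=(None, 0))
_, ys = scan_rows(-jnp.inf, batch)
# ys → [[3., 3., 4.], [1., 5., 5.]]

# Variation 2: a nested body that closes over per-row state.
def scaled_running_max(row, scale):
    def scaled_step(carry, x):
        new = jnp.maximum(carry, scale * x)  # `scale` comes from the closure
        return new, new
    _, out = lax.scan(scaled_step, -jnp.inf, row)
    return out

scales = jnp.array([1., 2.])  # one hypothetical scale factor per row
scaled = jax.vmap(scaled_running_max)(batch, scales)
# scaled → [[3., 3., 4.], [2., 10., 10.]]
```

Variation 2 is where nesting pays off. The step function needs a value that differs per row, and the closure supplies it without threading it through the carry.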
Problem
Implement batched_running_sum(x_batch), which returns the per-row cumulative sum of a 2-D input, using vmap over lax.scan.
- x_batch: 2-D JAX array of shape (N, T).
- Returns: 2-D array of shape (N, T), where out[i, t] = sum(x_batch[i, :t+1]).
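One way to test a candidate is to compare it against jnp.cumsum, which computes the same per-row cumulative sum without a scan. check_running_sum is a hypothetical helper, not part of the exercise. It only verifies shape and values on one small fixed input.

```python
import jax.numpy as jnp

def check_running_sum(fn):
    """Hypothetical helper: compare a candidate batched_running_sum to jnp.cumsum."""
    # N=2 rows, T=3 steps, so a transposed result fails the shape check.
    x = jnp.array([[1., 2., 3.], [4., -1., 0.5]])
    expected = jnp.cumsum(x, axis=1)  # out[i, t] = sum(x[i, :t+1])
    got = fn(x)
    assert got.shape == x.shape, f"expected shape {x.shape}, got {got.shape}"
    assert jnp.allclose(got, expected), f"mismatch:\n{got}\nvs\n{expected}"
    return True
```

Calling check_running_sum(batched_running_sum) should return True once the implementation is correct.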