hard primitives

vmap over lax.scan

Why this matters

lax.scan is JAX's functional loop for sequential state: the idiomatic way to implement RNNs, cumulative operations, and any computation where each step depends on the previous one. In practice, though, you usually need to run that loop over a whole batch of inputs at once.

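The solution is vmap(per_row): wrap the scan in a helper that operates on a single row, then vmap that helper across the batch dimension. Semantically, each row gets its own independent scan. Under the hood, vmap's batching rule for scan does not launch N separate loops. It rewrites the computation into a single scan over time whose carry and per-step slices gain a batch dimension, so every step processes all rows at once as vectorised operations that XLA can run efficiently. This vmap-of-scan composition is a standard pattern for batched RNNs and other batched sequential computations in JAX.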
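To make the batching behaviour concrete, here is a minimal sketch that compares vmap-of-scan with two alternatives: an explicit Python loop over rows, and a hand-batched single lax.scan over the time axis whose carry holds one value per row (roughly the shape of what vmap's batching rule produces). It uses an exponential moving average rather than the running max from the worked example; ALPHA, ema_row and ema_batch_manual are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

ALPHA = 0.5  # hypothetical smoothing factor, chosen only for illustration

def ema_row(row):
    """Exponential moving average of one 1-D row, via lax.scan."""
    def step(carry, x):
        new = ALPHA * carry + (1.0 - ALPHA) * x
        return new, new
    _, ys = lax.scan(step, row[0], row)  # start the average at the first element
    return ys

def ema_batch_manual(x_batch):
    """The same computation written as ONE scan over time.
    The carry has shape (N,): one running value per row."""
    def step(carry, x_t):  # x_t is column t, shape (N,)
        new = ALPHA * carry + (1.0 - ALPHA) * x_t
        return new, new
    _, ys = lax.scan(step, x_batch[:, 0], x_batch.T)  # ys has shape (T, N)
    return ys.T                                       # back to (N, T)

x = jnp.array([[1.0, 3.0, 5.0, 7.0],
               [8.0, 0.0, 4.0, 2.0]])

via_vmap = jax.vmap(ema_row)(x)                # vmap-of-scan
via_loop = jnp.stack([ema_row(r) for r in x])  # one scan per row, driven from Python
via_manual = ema_batch_manual(x)               # hand-batched single scan
# All three give [[1., 2., 3.5, 5.25], [8., 4., 4., 3.]]
```

If you want to see the rewrite yourself, printing jax.make_jaxpr(jax.vmap(ema_row))(x) should show a single scan primitive operating on batched values rather than N separate loops.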

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

def running_max_batch(x_batch):
    """Cumulative max for each row."""
    def per_row(row):
        def step(carry, x):
            new = jnp.maximum(carry, x)
            return new, new
        _, ys = lax.scan(step, -jnp.inf, row)
        return ys
    return jax.vmap(per_row)(x_batch)

running_max_batch(jnp.array([[3., 1., 4.], [1., 5., 2.]]))
# → [[3., 3., 4.], [1., 5., 5.]]

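lax.scan(step, init, xs) calls step(carry, x) once per element of xs; step must return (new_carry, y). The scan itself returns (final_carry, stacked_outputs), where the second element stacks the y emitted at each step along a new leading axis of length T.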
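A small sketch of both return values on a single row. The per-step output y does not have to equal the carry: here it is a tuple (running max, "is this a new record?"), and scan stacks each element of the tuple separately. lax.cummax is used only as an independent cross-check.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new = jnp.maximum(carry, x)
    is_record = x > carry          # y can be any pytree, not just the carry
    return new, (new, is_record)   # (next carry, per-step output y)

row = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
final_carry, (running_max, records) = lax.scan(step, -jnp.inf, row)

# final_carry is the carry after the last step: 5.0
# running_max stacks every emitted max:          [3., 3., 4., 4., 5.]
# records stacks every emitted flag:             [True, False, True, False, True]

# lax.cummax computes the same cumulative max directly.
reference = lax.cummax(row, axis=0)
```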

Common pitfalls

  • Defining the scan body outside per_row. This is not actually an error: a module-level step function that captures nothing works fine. The usual idiom is still to nest it inside per_row, which keeps the body next to its only use and lets it close over per-row values.
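  • Trying to vmap lax.scan itself. jax.vmap maps a function over array arguments, but lax.scan's first argument is the step function, a Python callable rather than an array. Instead of jax.vmap(lax.scan, ...), put the scan inside a helper whose arguments are all arrays (here per_row(row)) and vmap the helper.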
  • Forgetting to unpack outputs. lax.scan returns (final_carry, ys). Discard the carry with _ when you only need the per-step outputs.
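The pitfalls above can be seen together in one sketch: the scan lives inside an array-only helper, the step closes over a per-row value, and the final carry is discarded with _. The extra arguments scale and start are hypothetical, added only to show closure over per-row state and how in_axes separates per-row arguments (axis 0) from shared ones (None).

```python
import jax
import jax.numpy as jnp
from jax import lax

def per_row(row, scale, start):
    """Running max of scale * row, beginning from `start`.

    `scale` and `start` are hypothetical extras for illustration. The step
    closes over `scale`, which works because step is defined inside per_row."""
    def step(carry, x):
        new = jnp.maximum(carry, scale * x)
        return new, new
    _, ys = lax.scan(step, start, row)  # discard the final carry with _
    return ys

x_batch = jnp.array([[3.0, 1.0, 4.0],
                     [1.0, 5.0, 2.0]])
scales = jnp.array([2.0, 0.5])  # one scale per row

# Map over rows of x_batch and entries of scales; share `start` across rows.
batched = jax.vmap(per_row, in_axes=(0, 0, None))
out = batched(x_batch, scales, 0.0)
# Row 0 scaled: [6, 2, 8]     -> running max [6, 6, 8]
# Row 1 scaled: [0.5, 2.5, 1] -> running max [0.5, 2.5, 2.5]
```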

Problem

Implement batched_running_sum(x_batch) that returns the per-row cumulative sum of a 2-D input using vmap of lax.scan.

  • x_batch: 2-D JAX array of shape (N, T).
  • Returns: 2-D array of shape (N, T), where out[i, t] = sum(x_batch[i, :t+1]).
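Before writing a solution, it can help to have something to test against. The sketch below uses a hypothetical helper, check_running_sum, that compares any candidate function with jnp.cumsum(x, axis=1), which computes exactly out[i, t] = sum(x_batch[i, :t+1]). It checks the output only; it does not show the scan-based solution or verify that vmap and lax.scan were used.

```python
import jax
import jax.numpy as jnp

def check_running_sum(fn, n=4, t=7, seed=0):
    """Hypothetical test helper: compare fn against jnp.cumsum along axis 1."""
    x = jax.random.normal(jax.random.PRNGKey(seed), (n, t))
    got = fn(x)
    want = jnp.cumsum(x, axis=1)  # out[i, t] = sum(x[i, :t+1])
    assert got.shape == (n, t), f"expected shape {(n, t)}, got {got.shape}"
    return bool(jnp.allclose(got, want, atol=1e-5))

# A tiny hand-checkable case of the specification.
small = jnp.array([[1.0, 2.0, 3.0],
                   [4.0, 0.0, -1.0]])
expected_small = jnp.cumsum(small, axis=1)  # [[1., 3., 6.], [4., 4., 3.]]
```

Once you have an implementation, check_running_sum(batched_running_sum) should return True.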

Hints

jax vmap scan composition
