medium primitives

Running Mean via lax.scan

Why this matters

In the previous problem the carry was a single scalar. Real use cases, such as RNN hidden states, optimiser moment estimates, and online statistics, need carries that are tuples or other pytrees. JAX's lax.scan accepts any pytree as the carry, with one constraint: init must have the same pytree structure, leaf shapes, and leaf dtypes as the carry the body returns as the first element of its output pair.
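As an illustration of a pytree carry that is not a tuple, here is a minimal sketch that tracks a running minimum and maximum in a dict. The function name minmax_step and the keys "lo" and "hi" are hypothetical, chosen only for this example. The point is that init and the carry returned by the body share the same keys, shapes, and dtype.

```python
import jax.numpy as jnp
from jax import lax

def minmax_step(carry, x_i):
    # Hypothetical body: the carry is a dict pytree with keys "lo" and "hi".
    new_carry = {
        "lo": jnp.minimum(carry["lo"], x_i),
        "hi": jnp.maximum(carry["hi"], x_i),
    }
    # Return (new carry, per-step output); the per-step output is a tuple pytree.
    return new_carry, (new_carry["lo"], new_carry["hi"])

xs = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
# init has the same keys, scalar shapes, and float32 dtype as new_carry.
init = {"lo": xs[0], "hi": xs[0]}
final, (lows, highs) = lax.scan(minmax_step, init, xs)
# lows  → [3. 1. 1. 1. 1.]
# highs → [3. 3. 4. 4. 5.]
# final["lo"] → 1.0, final["hi"] → 5.0
```

lax.scan stacks each leaf of the per-step output along a new leading axis, which is why the output tuple comes back as two arrays of length 5.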

Running mean is the minimal exercise for tuple carries: each step needs both a running sum and a count, so the carry is (sum, count). The same pattern extends directly to training an RNN with scan, where the carry holds the hidden state.
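Written as a recurrence over 1-based steps i = 1 to N (so m_i corresponds to out[i-1] in the problem below), the carry after step i is the pair (S_i, n_i), and each output is their ratio:

```latex
S_0 = 0, \qquad n_0 = 0
S_i = S_{i-1} + x_i, \qquad n_i = n_{i-1} + 1, \qquad m_i = \frac{S_i}{n_i}
```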

Worked mini-example

from jax import lax
import jax.numpy as jnp

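def step(carry, x_i):
    s, n = carry              # unpack the tuple carry: (running sum, count)
    new_s = s + x_i
    new_n = n + 1
    # First element: the new carry (same structure as init); second: this step's output.
    return (new_s, new_n), new_s / new_n

_, means = lax.scan(step, (0.0, 0), jnp.array([4.0, 2.0]))
# means → [4.0, 3.0]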

The carry is a 2-tuple, and init is (0.0, 0): the sum starts at 0.0 and the count at 0.
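Because step is an ordinary Python function, one way to inspect the carry is to call it eagerly outside scan and record each intermediate value. The sketch below does that with a Python loop purely for inspection (the exercise itself forbids loops) and compares the result with lax.scan. The names trace and scanned are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x_i):
    s, n = carry
    new_s = s + x_i
    new_n = n + 1
    return (new_s, new_n), new_s / new_n

xs = jnp.array([4.0, 2.0])

# Eager walk-through, for inspection only: record (sum, count, output) per step.
carry = (0.0, 0)
trace = []
for x_i in xs:
    carry, y = step(carry, x_i)
    trace.append((float(carry[0]), int(carry[1]), float(y)))
# trace → [(4.0, 1, 4.0), (6.0, 2, 3.0)]

# The same body driven by lax.scan produces the same outputs.
_, scanned = lax.scan(step, (0.0, 0), xs)
# scanned → [4. 3.]
```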

Common pitfalls

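  • Mismatched carry structure. init and the first element returned by the body must have the same pytree structure, leaf shapes, and dtypes. Returning (new_s, new_n) when init is a bare scalar, or turning a jnp.int32 count into a float inside the body, makes lax.scan raise a TypeError. Differences only in weak type (for example, a plain Python 0.0 init for a float32 sum) are promoted automatically in recent JAX versions.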
  • Updating only one field. Forgetting to increment n leaves the count at 0, so every output divides by zero and comes out as ±inf (or nan wherever the running sum is 0).
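  • Count dtype and division. In JAX, / is true division, so new_s / new_n is a float even when the count is an integer. A plain Python 0 is weakly typed and works in practice; initialising with jnp.int32(0), as in (0.0, jnp.int32(0)), pins the count's dtype explicitly so the carry type stays unambiguous under jit.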
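The dtype mismatch from the first pitfall can be reproduced with a deliberately broken body. In this sketch, bad_step is a hypothetical name, and the init leaves are created with explicit dtypes so the only mismatch is the count changing from int32 to a float. The exact error text differs between JAX versions, so the code only checks that a TypeError is raised.

```python
import jax.numpy as jnp
from jax import lax

def bad_step(carry, x_i):
    s, n = carry
    # Deliberate bug: adding 1.0 turns the int32 count into a float,
    # so the returned carry no longer matches init's dtypes.
    return (s + x_i, n + 1.0), s

xs = jnp.array([4.0, 2.0])
init = (jnp.float32(0.0), jnp.int32(0))  # strongly typed leaves

raised = False
try:
    lax.scan(bad_step, init, xs)
except TypeError:
    # lax.scan rejects a carry whose output dtype differs from its input dtype.
    raised = True
print("carry type mismatch raised TypeError:", raised)
```

Changing n + 1.0 back to n + 1 keeps the count as int32 and the scan runs normally.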

Problem

Implement running_mean(x) using lax.scan with a tuple carry (running_sum, count). Do not use jnp.cumsum or Python loops.

  • x: 1-D jax array.
  • Returns: 1-D array, same shape. out[i] = mean(x[:i+1]).
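To check an answer, it can help to compare it with a straightforward reference. The helper below, reference_running_mean, is hypothetical and deliberately uses jnp.cumsum, so it is not a valid submission; it only produces expected values.

```python
import jax.numpy as jnp

def reference_running_mean(x):
    # Hypothetical test helper, NOT a valid answer: it uses jnp.cumsum,
    # which the exercise forbids in the scan-based solution.
    counts = jnp.arange(1, x.shape[0] + 1, dtype=x.dtype)
    return jnp.cumsum(x) / counts

x = jnp.array([4.0, 2.0, 6.0])
expected = reference_running_mean(x)
# expected → [4. 3. 4.]
```

With a scan-based running_mean of your own, jnp.allclose(running_mean(x), expected) should evaluate to True.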

Hints

jax scan carry
