Running Mean via lax.scan
Why this matters
In the previous problem the carry was a single scalar. Real use cases
(RNN hidden states, optimiser moment estimates, online statistics) need
carries that are tuples or other pytrees. JAX's lax.scan accepts any
pytree as the carry: the structure of init, and the shape and dtype of
each leaf, must match the carry the body returns as the first element
of its output pair.
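To illustrate that the carry can be any pytree, not just a tuple, here is a minimal sketch that tracks exponential moving averages of x and x**2 (in the spirit of optimiser moment estimates) in a dict carry. The function name ema_moments and the decay parameter beta are hypothetical; the point is only that the dict returned by the body has the same keys, shapes and dtypes as init.

```python
import jax.numpy as jnp
from jax import lax


def ema_moments(x, beta=0.9):
    """Hypothetical helper: EMAs of x and x**2, carried as a dict pytree."""

    def step(carry, x_i):
        # Both leaves are updated, and the returned dict keeps the same keys,
        # shapes and dtypes as `init`, which lax.scan requires of the carry.
        new_carry = {
            "m": beta * carry["m"] + (1.0 - beta) * x_i,
            "v": beta * carry["v"] + (1.0 - beta) * x_i ** 2,
        }
        # Second element: per-step output, stacked by scan into an array.
        return new_carry, new_carry["m"]

    init = {"m": jnp.zeros((), x.dtype), "v": jnp.zeros((), x.dtype)}
    final, m_history = lax.scan(step, init, x)
    return final, m_history


# Usage: `final` is a dict with the last "m" and "v"; `m_history` has x's shape.
final, m_history = ema_moments(jnp.array([1.0, 2.0, 3.0]))
```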
Running mean is the minimal exercise for tuple carries: you need both a
running sum and a count, so the carry is (sum, count). Mastering this
pattern unlocks full RNN training with scan.
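As a rough sketch of where the pattern leads, assuming a vanilla tanh RNN: the hidden state is the scan carry, and jax.grad differentiates straight through lax.scan. The names rnn_forward, loss, W_xh, W_hh and b, and the toy sizes, are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def rnn_forward(params, h0, xs):
    """Hypothetical vanilla tanh RNN unrolled over time with lax.scan."""
    W_xh, W_hh, b = params["W_xh"], params["W_hh"], params["b"]

    def step(h, x_t):
        # Carry = hidden state; its shape (hidden,) is the same at every step.
        h_new = jnp.tanh(x_t @ W_xh + h @ W_hh + b)
        return h_new, h_new

    h_final, hs = lax.scan(step, h0, xs)  # hs has shape (T, hidden)
    return h_final, hs


def loss(params, h0, xs, target):
    # Toy loss on the final hidden state only.
    h_final, _ = rnn_forward(params, h0, xs)
    return jnp.sum((h_final - target) ** 2)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
d_in, hidden, T = 3, 4, 5
params = {
    "W_xh": 0.1 * jax.random.normal(k1, (d_in, hidden)),
    "W_hh": 0.1 * jax.random.normal(k2, (hidden, hidden)),
    "b": jnp.zeros(hidden),
}
xs = jnp.ones((T, d_in))
h0 = jnp.zeros(hidden)

# Gradients flow back through every scan step; result mirrors `params`.
grads = jax.grad(loss)(params, h0, xs, jnp.zeros(hidden))
```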
Worked mini-example
from jax import lax
import jax.numpy as jnp

def step(carry, x_i):
    s, n = carry            # carry = (running sum, count)
    new_s = s + x_i
    new_n = n + 1
    # (new carry, per-step output); scan stacks the outputs into `means`
    return (new_s, new_n), new_s / new_n

_, means = lax.scan(step, (0.0, 0), jnp.array([4.0, 2.0]))
# means → [4.0, 3.0]
The carry is a 2-tuple; init is (0.0, 0): the sum starts at 0.0 and
the count starts at 0.
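lax.scan returns a pair (final_carry, stacked_outputs); the mini-example discards the final carry with _. A small sketch that keeps it, assuming explicit dtypes for both leaves (jnp.float32 for the sum, jnp.int32 for the count) instead of Python scalars:

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x_i):
    s, n = carry                 # running sum, count
    new_s = s + x_i
    new_n = n + 1
    return (new_s, new_n), new_s / new_n


xs = jnp.array([4.0, 2.0, 6.0])
init = (jnp.float32(0.0), jnp.int32(0))  # explicit dtypes for both leaves
(total, count), means = lax.scan(step, init, xs)
```

Here total is the sum of all elements and count the number of elements, so total / count equals the last entry of means.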
Common pitfalls
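- Mismatched carry structure. init and the first value returned by the body must have the same pytree structure, with matching leaf shapes and dtypes. Returning a list [new_s, new_n] when init is a tuple is a structure mismatch and raises an error; a dtype mismatch (for example a float sum initialised with an integer) may be silently promoted or rejected, depending on whether the init value is a weakly typed Python scalar and on the JAX version.
- Updating only one field. Computing new_n but carrying the old n forward (returning (new_s, n)) keeps the carried count at 0, so new_n is 1 at every step and the output is the running sum instead of the running mean.
- Integer division. / is true division in JAX, so new_s / new_n yields a float mean even with an integer count. A plain Python 0 for the count is weakly typed and promotes fine in practice; initialising it as jnp.int32(0), i.e. init = (0.0, jnp.int32(0)), makes the count's dtype explicit.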
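To make the "Updating only one field" pitfall concrete, here is a sketch comparing a buggy body (the name buggy_step is hypothetical) with the correct one on the mini-example's input:

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x_i):
    s, n = carry
    new_s = s + x_i
    new_n = n + 1
    return (new_s, new_n), new_s / new_n


def buggy_step(carry, x_i):
    s, n = carry
    new_s = s + x_i
    new_n = n + 1
    # Bug: the old count `n` is carried forward instead of `new_n`,
    # so the carried count stays 0 and `new_n` is 1 at every step.
    return (new_s, n), new_s / new_n


xs = jnp.array([4.0, 2.0])
init = (jnp.float32(0.0), jnp.int32(0))
_, good = lax.scan(step, init, xs)
_, bad = lax.scan(buggy_step, init, xs)
# good holds the running mean [4.0, 3.0]; bad holds the running sum [4.0, 6.0]
```

Note that the buggy body still type-checks: its carry has the same structure and dtypes as init, so scan cannot catch the mistake for you.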
Problem
Implement running_mean(x) using lax.scan with a tuple carry
(running_sum, count). Do not use jnp.cumsum or Python loops.
- x: 1-D JAX array.
- Returns: 1-D array of the same shape, with out[i] = mean(x[:i+1]).
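When testing a solution, a closed-form reference is handy. The helper names below (reference_running_mean, check_running_mean) are hypothetical, and jnp.cumsum appears only in the checker, not in the scan-based solution the problem asks for.

```python
import jax.numpy as jnp


def reference_running_mean(x):
    # Closed form, for checking only: prefix sums divided by 1, 2, ..., n.
    return jnp.cumsum(x) / jnp.arange(1, x.shape[0] + 1)


def check_running_mean(candidate, x, rtol=1e-5):
    # Compare a candidate against the closed form. A tolerance is used
    # because the two may accumulate float32 rounding differently.
    out = candidate(x)
    return out.shape == x.shape and bool(
        jnp.allclose(out, reference_running_mean(x), rtol=rtol)
    )
```

For example, check_running_mean(running_mean, jnp.array([4.0, 2.0, 6.0])) should return True for a correct implementation.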