Random Walk via lax.scan
Why this matters
Sequential stochastic processes (random walks, MCMC chains, sequential
Monte Carlo) require threading a PRNG key through a loop while accumulating
outputs. lax.scan is JAX's functional, JIT-friendly way to do this.
The pattern: carry (key, state), split the key at every step, draw a
new sample, and advance the state. This appears in particle filters,
diffusion model trajectory sampling, and any recurrent process with
stochastic inputs.
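A minimal sketch of that carry/split/sample/advance pattern, assuming the state is a single float32 scalar. `scan_with_key` and `walk_step` are hypothetical helper names for illustration, not JAX APIs. Only `jax.random.split`, `jax.random.normal`, and `lax.scan` come from JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_with_key(step_fn, key, init_state, n_steps):
    """Hypothetical helper: apply step_fn n_steps times, threading a PRNG key.

    step_fn(subkey, state) -> new_state. Returns the stacked post-step states.
    """
    def body(carry, _):
        key, state = carry
        # Split so that every step draws from a fresh subkey.
        key, subkey = jax.random.split(key)
        new_state = step_fn(subkey, state)
        # New carry first, then this step's output.
        return (key, new_state), new_state

    # xs=None plus an explicit length: no per-step input is needed.
    _, states = lax.scan(body, (key, init_state), xs=None, length=n_steps)
    return states


def walk_step(subkey, pos, step_std=1.0):
    # Hypothetical step function: one Gaussian increment with std step_std.
    return pos + step_std * jax.random.normal(subkey, dtype=jnp.float32)


path = scan_with_key(walk_step, jax.random.PRNGKey(0), jnp.float32(0.0), 5)
print(path.shape)  # (5,)
```

Swapping `walk_step` for another `step_fn` reuses the same key-threading logic for any recurrent process with stochastic inputs.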
Worked mini-example
n_steps = 2, step_std = 1.0, seed = 0.
carry0 = (PRNGKey(0), 0.0)
step 1: split key, draw delta ~ N(0,1); new_pos = 0 + delta; carry = (new_key, new_pos)
step 2: split new_key, draw delta ~ N(0,1); new_pos += delta
outputs: [pos_after_step1, pos_after_step2]
Initial position 0 is excluded from the output; only the n_steps
post-step positions are returned.
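The two steps of the mini-example can be unrolled by hand and checked against lax.scan. This sketch assumes the split convention `key, subkey = jax.random.split(key)` with the increment drawn from `subkey`. A different convention gives different (equally valid) numbers, so no specific values are shown.

```python
import jax
import jax.numpy as jnp
from jax import lax

step_std = 1.0
key0 = jax.random.PRNGKey(0)

# Manual unrolling: step 1, then step 2, each with a freshly split subkey.
key, subkey = jax.random.split(key0)
pos1 = 0.0 + step_std * jax.random.normal(subkey, dtype=jnp.float32)
key, subkey = jax.random.split(key)
pos2 = pos1 + step_std * jax.random.normal(subkey, dtype=jnp.float32)
manual = jnp.stack([pos1, pos2])  # initial position 0 is not included


# The same two steps expressed with lax.scan.
def body(carry, _):
    key, pos = carry
    key, subkey = jax.random.split(key)
    pos = pos + step_std * jax.random.normal(subkey, dtype=jnp.float32)
    return (key, pos), pos


_, scanned = lax.scan(body, (key0, jnp.float32(0.0)), jnp.zeros(2))
print(jnp.allclose(manual, scanned))  # True
```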
Common pitfalls
- Not splitting the key per step: reusing the same key in every iteration produces the identical delta at every step, so the position moves in a straight line (k * delta after k steps) instead of performing a random walk.
- Dummy xs: lax.scan needs to know how many steps to run. Either pass a length-n dummy such as jnp.zeros(n) as xs, which the scan body ignores by naming the per-step value _, or pass xs=None together with length=n.
- Carry structure: the carry here is the tuple (key, position). JAX accepts any pytree as the carry, but the body must return a carry with exactly the same structure, shapes, and dtypes it received, or lax.scan raises an error at trace time.
- Return value: scan returns (final_carry, stacked_outputs). Unpack as (_, _), positions = lax.scan(...).
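The first pitfall is easy to observe directly. In this sketch (`bad_body` and `good_body` are illustrative names), `bad_body` never splits, so every step calls `jax.random.normal` with the same key and the increments come out identical. `good_body` splits on every step. The second call also shows `xs=None` with `length=n` as an alternative to a dummy `xs` array.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4
key0 = jax.random.PRNGKey(0)
init = (key0, jnp.float32(0.0))


def bad_body(carry, _):
    key, pos = carry
    # Bug: the same key is reused, so delta is identical at every step.
    pos = pos + jax.random.normal(key, dtype=jnp.float32)
    return (key, pos), pos


def good_body(carry, _):
    key, pos = carry
    # Fix: keep `key` for later steps and consume `subkey` now.
    key, subkey = jax.random.split(key)
    pos = pos + jax.random.normal(subkey, dtype=jnp.float32)
    return (key, pos), pos


# Dummy xs of length n; the body ignores each element.
_, bad = lax.scan(bad_body, init, jnp.zeros(n))
# Same step count without a dummy array.
_, good = lax.scan(good_body, init, xs=None, length=n)

print(jnp.diff(bad))   # entries (nearly) equal: a straight line
print(jnp.diff(good))  # entries differ from step to step
```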
Problem
Implement random_walk(seed, n_steps, step_std):
- seed (float): converted to a key with jax.random.PRNGKey(int(seed))
- n_steps (float, cast to int): number of steps to simulate
- step_std (float): standard deviation of each Gaussian step
Return a 1-D float32 array of shape (n_steps,) containing the
position after each step (initial position 0 excluded).
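A minimal sketch of one possible implementation, assuming the split convention `key, subkey = jax.random.split(key)`. It is not the site's reference solution, and its exact values may differ from a checker that splits keys differently. `n_steps` is converted with `int()` outside any traced code because the scan length must be a concrete Python integer.

```python
import jax
import jax.numpy as jnp
from jax import lax


def random_walk(seed, n_steps, step_std):
    key = jax.random.PRNGKey(int(seed))
    n = int(n_steps)               # scan length must be a concrete int
    std = jnp.float32(step_std)

    def step(carry, _):
        key, pos = carry
        key, subkey = jax.random.split(key)          # fresh subkey every step
        delta = std * jax.random.normal(subkey, dtype=jnp.float32)
        new_pos = pos + delta
        return (key, new_pos), new_pos               # output = post-step position

    init = (key, jnp.float32(0.0))                   # position 0 is not emitted
    _, positions = lax.scan(step, init, jnp.zeros(n))
    return positions


positions = random_walk(0.0, 5.0, 0.5)
print(positions.shape, positions.dtype)  # (5,) float32
```

Because `step_std` only scales the standard-normal draws, the same seed with a larger `step_std` yields a proportionally scaled copy of the same path.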