Random Walk via lax.scan

Why this matters

Sequential stochastic processes (random walks, MCMC chains, sequential Monte Carlo) require threading a PRNG key through a loop while accumulating outputs. lax.scan is JAX's functional, JIT-friendly way to do this. The pattern is:

1. Carry (key, state).
2. Split the key at every step.
3. Draw a new sample with the fresh subkey.
4. Advance the state.

The same pattern appears in particle filters, trajectory sampling for diffusion models, and any recurrent process with stochastic inputs.
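As a minimal sketch of that pattern, a generic helper can thread the key through the carry and leave the per-step update to the caller. It assumes a recent JAX. `scan_with_key` and its `step` argument are hypothetical names introduced here, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_with_key(step, key, init_state, n_steps):
    """Run n_steps of a stochastic recurrence and return the state after each step.

    `step(subkey, state) -> new_state` is a hypothetical user-supplied update.
    `n_steps` must be a concrete Python int (lax.scan's `length` is static).
    """
    def body(carry, _):
        key, state = carry
        key, subkey = jax.random.split(key)  # fresh subkey every step
        new_state = step(subkey, state)
        return (key, new_state), new_state   # (next carry, per-step output)

    _, states = lax.scan(body, (key, init_state), xs=None, length=n_steps)
    return states


# Usage: a Gaussian random walk with unit-variance steps, starting at 0.
walk = scan_with_key(
    lambda k, x: x + jax.random.normal(k),
    jax.random.PRNGKey(0),
    jnp.float32(0.0),
    5,
)
```

Keeping the key inside the carry means the body stays a pure function of its inputs. That purity is what lets lax.scan trace it once and compile the loop.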

Worked mini-example

n_steps = 2, step_std = 1.0, seed = 0.

carry0 = (PRNGKey(0), 0.0)
step 1: key, subkey = split(key); delta = step_std * N(0,1) drawn with subkey; pos = 0.0 + delta; carry = (key, pos)
step 2: key, subkey = split(key); delta = step_std * N(0,1) drawn with subkey; pos = pos + delta; carry = (key, pos)
outputs: [pos_after_step1, pos_after_step2]

Initial position 0 is excluded from the output; only the n_steps post-step positions are returned.
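The two steps can be traced by hand and checked against lax.scan. This sketch assumes each step splits the carried key into (next key, subkey) and draws delta from the subkey. Other split conventions are equally valid but produce different numbers, so no concrete values are shown.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_steps, step_std, seed = 2, 1.0, 0

# Hand-unrolled trace of the two steps.
key = jax.random.PRNGKey(seed)
pos = jnp.float32(0.0)
manual = []
for _ in range(n_steps):
    key, subkey = jax.random.split(key)           # keep `key` for the next step
    delta = step_std * jax.random.normal(subkey)  # delta ~ N(0, step_std**2)
    pos = pos + delta
    manual.append(pos)                            # initial 0 is never recorded
manual = jnp.stack(manual)


# The same recurrence expressed with lax.scan.
def body(carry, _):
    key, pos = carry
    key, subkey = jax.random.split(key)
    pos = pos + step_std * jax.random.normal(subkey)
    return (key, pos), pos


_, scanned = lax.scan(body, (jax.random.PRNGKey(seed), jnp.float32(0.0)),
                      xs=None, length=n_steps)
```

The two results should agree up to floating-point rounding, because the key-splitting sequence is identical.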

Common pitfalls

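  • Not splitting the key per step: reusing the same key in every iteration draws the identical delta at every step. The "walk" then moves in a straight line (position k * delta after k steps) instead of wandering randomly.
  • Dummy xs: lax.scan must know how many steps to run. It infers this from the leading axis of xs, so a common idiom is a length-n dummy such as jnp.zeros(n), whose per-step value the body ignores as _. Equivalently, pass xs=None together with length=n. Either way, n must be a concrete Python int, not a traced value.
  • Carry structure: the body must return a carry with exactly the same pytree structure, shapes and dtypes it received. A (key, position) tuple is conventional. Start the position as a float32 scalar (e.g. jnp.float32(0.0)) so its dtype matches the float32 deltas.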
  • Return value: scan returns (final_carry, stacked_outputs). Unpack as (_, _), positions = lax.scan(...).
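The sketch below contrasts the first pitfall with its fix; the function names are illustrative only. With a reused key every increment is the same number, while splitting yields distinct increments. It also shows the dummy-xs form and the xs=None, length= form side by side.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body_reused_key(carry, _):
    key, pos = carry
    delta = jax.random.normal(key)        # BUG: same key every step -> same delta
    pos = pos + delta
    return (key, pos), pos


def body_split_key(carry, _):
    key, pos = carry
    key, subkey = jax.random.split(key)   # fix: fresh subkey per step
    pos = pos + jax.random.normal(subkey)
    return (key, pos), pos


init = (jax.random.PRNGKey(0), jnp.float32(0.0))
_, bad = lax.scan(body_reused_key, init, jnp.zeros(4))       # dummy xs sets 4 steps
_, good = lax.scan(body_split_key, init, xs=None, length=4)  # equivalent via length=

# Per-step increments (prepend the implicit starting position 0).
bad_steps = jnp.diff(bad, prepend=0.0)
good_steps = jnp.diff(good, prepend=0.0)
```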

Problem

Implement random_walk(seed, n_steps, step_std):

  • seed (float) → jax.random.PRNGKey(int(seed))
  • n_steps (float, cast to int): number of steps to simulate
  • step_std (float): standard deviation of each Gaussian step

Return a 1-D float32 array of shape (n_steps,) containing the position after each step (initial position 0 excluded).
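One possible implementation, sketched under the spec above (this is not the site's reference solution). It passes xs=None with length=n instead of a dummy array; the two forms are equivalent for lax.scan.

```python
import jax
import jax.numpy as jnp
from jax import lax


def random_walk(seed, n_steps, step_std):
    key = jax.random.PRNGKey(int(seed))
    n = int(n_steps)  # must be a static Python int for `length`

    def body(carry, _):
        key, pos = carry
        key, subkey = jax.random.split(key)  # fresh randomness each step
        delta = step_std * jax.random.normal(subkey, dtype=jnp.float32)
        new_pos = pos + delta
        return (key, new_pos), new_pos       # output is the post-step position

    init = (key, jnp.float32(0.0))           # float32 start matches delta's dtype
    _, positions = lax.scan(body, init, xs=None, length=n)
    return positions.astype(jnp.float32)
```

Because the whole walk is driven by one seed, calling the function twice with the same arguments returns the same array. Scaling step_std scales every position by the same factor.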

