HMC Leapfrog Step
Why this matters
Hamiltonian Monte Carlo (HMC) avoids the random-walk behaviour of vanilla Metropolis-Hastings (MH) by augmenting the state space with a momentum variable and simulating Hamiltonian dynamics. The leapfrog integrator is the numerical integrator used to evolve these dynamics: it alternates half-step momentum updates with full-step position updates. It is symplectic (and therefore volume-preserving) and time-reversible. Those are the two properties that let a simple Metropolis accept/reject step correct for discretization error and give detailed balance. HMC with leapfrog is the engine behind NUTS (No-U-Turn Sampler), which powers Stan, PyMC, and NumPyro.
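Written out for one step, the update looks as follows. This assumes an identity mass matrix, so velocity equals momentum, and a Hamiltonian H(x, p) = -log π(x) + p²/2. Here ε stands for step_size and π for the target density. The sign convention matches the worked example, where grad is the gradient of log π:

```latex
\begin{aligned}
p_{t+1/2} &= p_t + \tfrac{\varepsilon}{2}\,\nabla_x \log \pi(x_t) \\
x_{t+1}   &= x_t + \varepsilon\, p_{t+1/2} \\
p_{t+1}   &= p_{t+1/2} + \tfrac{\varepsilon}{2}\,\nabla_x \log \pi(x_{t+1})
\end{aligned}
```

When the gradient is constant, both momentum half-steps add the same amount, ε/2 times the gradient.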
Worked mini-example
One leapfrog step with constant grad = 0 (free particle), step_size = 0.1:
x, p = 0.0, 1.0; grad = 0.0; step = 0.1
# Half-step p
p_half = p + 0.5 * step * grad # 1.0 + 0 = 1.0
# Full-step x
x_new = x + step * p_half # 0.0 + 0.1 * 1.0 = 0.1
# Half-step p again
p_new = p_half + 0.5 * step * grad # 1.0 + 0 = 1.0
# → (0.1, 1.0)
After 5 steps: x = 0.5, p = 1.0 (constant velocity, zero gradient).
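Wrapping the same three updates in a loop makes it easy to check the 5-step result and to try a nonzero constant gradient. This is a minimal plain-Python sketch, and the helper name leapfrog_const_grad is hypothetical. With a constant gradient g the per-step updates telescope, so leapfrog reproduces the exact uniform-force trajectory up to rounding: p = p0 + g*T and x = x0 + p0*T + g*T**2/2, with T = n_steps * step.

```python
def leapfrog_const_grad(x, p, grad, step, n_steps):
    """Hypothetical helper: plain-Python leapfrog with a constant grad of log-prob."""
    for _ in range(n_steps):
        p_half = p + 0.5 * step * grad  # opening half-step in momentum
        x = x + step * p_half           # full step in position
        p = p_half + 0.5 * step * grad  # closing half-step in momentum
    return x, p


# Free particle from the worked example: 5 steps of size 0.1.
print(leapfrog_const_grad(0.0, 1.0, 0.0, 0.1, 5))

# Nonzero constant gradient: compare against the exact uniform-force solution.
x0, p0, g, step, n = 0.0, 1.0, 2.0, 0.1, 10
T = n * step
print(leapfrog_const_grad(x0, p0, g, step, n))
print(x0 + p0 * T + 0.5 * g * T**2, p0 + g * T)
```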
Common pitfalls
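- Missing a half-step: each leapfrog iteration has two half-step momentum updates, one at the beginning and one at the end. Skipping either makes the update asymmetric, so it is no longer time-reversible. It also stops being a consistent discretization of Hamilton's equations, because the momentum then receives only half the force per step. The Metropolis-corrected chain will then generally fail to leave the target invariant.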
- Use lax.fori_loop, not a Python for-loop: a Python loop is unrolled at trace time (n_steps copies of the body), and it fails outright under jax.jit when n_steps is a traced JAX integer. lax.fori_loop keeps the loop as a single XLA op.
- Cast n_steps to int32: lax.fori_loop(0, jnp.int32(n_steps), ...). Passing a float will raise a type error.
- Constant grad assumption: this implementation treats grad_log_prob as a scalar constant. That corresponds to a linear potential (a uniform force), and grad_log_prob = 0 is the free particle of the worked example. Real HMC recomputes the gradient at each new position via jax.grad.
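A round-trip test makes the first pitfall concrete. Integrate forward, negate the momentum, integrate again, and negate back: a time-reversible integrator lands on the starting point. The sketch below compares full leapfrog with a variant that drops the closing momentum half-step. It assumes a standard normal target (grad log π(x) = -x), and all function names are hypothetical.

```python
def leapfrog(x, p, grad_log_prob, step, n_steps):
    """Full leapfrog: half-step p, full-step x, half-step p, every iteration."""
    for _ in range(n_steps):
        p = p + 0.5 * step * grad_log_prob(x)
        x = x + step * p
        p = p + 0.5 * step * grad_log_prob(x)
    return x, p


def leapfrog_missing_half_step(x, p, grad_log_prob, step, n_steps):
    """Buggy variant (hypothetical): the closing momentum half-step is dropped."""
    for _ in range(n_steps):
        p = p + 0.5 * step * grad_log_prob(x)
        x = x + step * p
    return x, p


def std_normal_grad(x):
    """Gradient of log pi(x) = -x**2 / 2 (standard normal, up to a constant)."""
    return -x


def round_trip(integrator, x0, p0, step, n_steps):
    """Integrate forward, flip p, integrate again, flip p back."""
    x, p = integrator(x0, p0, std_normal_grad, step, n_steps)
    x, p = integrator(x, -p, std_normal_grad, step, n_steps)
    return x, -p


print(round_trip(leapfrog, 1.0, 0.0, 0.1, 20))                    # back at (1, 0) up to rounding
print(round_trip(leapfrog_missing_half_step, 1.0, 0.0, 0.1, 20))  # does not return
```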
Problem
Implement leapfrog_step(x, p, grad_log_prob, step_size, n_steps) that runs
n_steps iterations of the leapfrog integrator with constant gradient.
All inputs are scalars (floats). Return a 1-D float32 array of shape (2,)
holding [final_x, final_p].
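A minimal JAX sketch of a function meeting this specification follows. It is one possible approach, not the site's reference solution, and it assumes float32 state throughout. It follows the pitfalls: a single lax.fori_loop, an int32 trip count, and both momentum half-steps using the same constant gradient. For a non-constant target, the constant g would instead be recomputed as jax.grad(log_prob)(x) after each position update, where log_prob is a user-supplied log-density (hypothetical here).

```python
import jax.numpy as jnp
from jax import lax


def leapfrog_step(x, p, grad_log_prob, step_size, n_steps):
    """Run n_steps leapfrog iterations with a constant grad_log_prob.

    Sketch only: assumes the gradient does not depend on x (uniform force).
    """
    x = jnp.float32(x)
    p = jnp.float32(p)
    g = jnp.float32(grad_log_prob)
    eps = jnp.float32(step_size)

    def body(_, state):
        x, p = state
        p_half = p + 0.5 * eps * g      # opening momentum half-step
        x_new = x + eps * p_half        # full position step
        p_new = p_half + 0.5 * eps * g  # closing momentum half-step
        return x_new, p_new

    # Integer upper bound, as lax.fori_loop requires integer loop bounds.
    x_final, p_final = lax.fori_loop(0, jnp.int32(n_steps), body, (x, p))
    return jnp.array([x_final, p_final], dtype=jnp.float32)


print(leapfrog_step(0.0, 1.0, 0.0, 0.1, 5))
```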