hard primitives

HMC Leapfrog Step

Why this matters

Hamiltonian Monte Carlo (HMC) avoids the random-walk behaviour of plain Metropolis-Hastings (MH) by augmenting the state space with a momentum variable and simulating Hamiltonian dynamics. The leapfrog integrator is the symplectic numerical integrator used to evolve these dynamics: it alternates half-step momentum updates with full-step position updates. The resulting map is volume-preserving and time-reversible, the two properties the Metropolis accept/reject step relies on for detailed balance. HMC with leapfrog is the engine behind NUTS (No-U-Turn Sampler), which powers Stan, PyMC, and NumPyro.

Worked mini-example

One leapfrog step with constant grad = 0 (free particle), step_size = 0.1:

x, p = 0.0, 1.0; grad = 0.0; step = 0.1

# Half-step p
p_half = p + 0.5 * step * grad   # 1.0 + 0 = 1.0
# Full-step x
x_new  = x + step * p_half        # 0.0 + 0.1 * 1.0 = 0.1
# Half-step p again
p_new  = p_half + 0.5 * step * grad  # 1.0 + 0 = 1.0
# → (x_new, p_new) = (0.1, 1.0)

After 5 steps: x = 0.5, p = 1.0 (constant velocity, zero gradient).
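The five-step result can be checked by repeating the single step in a plain Python loop. This is only a sketch with concrete floats: `leapfrog_constant_grad` is a hypothetical helper name, not the function the problem asks for, and it ignores the `lax.fori_loop` and float32 requirements.

```python
def leapfrog_constant_grad(x, p, grad, step, n_steps):
    """Repeat the worked single step n_steps times (plain floats, constant grad)."""
    for _ in range(n_steps):
        p = p + 0.5 * step * grad   # half-step momentum
        x = x + step * p            # full-step position
        p = p + 0.5 * step * grad   # half-step momentum
    return x, p


# Free particle from the worked example: grad = 0, so p never changes.
x5, p5 = leapfrog_constant_grad(0.0, 1.0, grad=0.0, step=0.1, n_steps=5)
print(x5, p5)  # x5 is 0.5 up to floating-point rounding; p5 stays 1.0
```

With a nonzero constant gradient g, the same loop reproduces constant-acceleration motion: p = p0 + n * step * g and x = x0 + n * step * p0 + 0.5 * (n * step)^2 * g.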

Common pitfalls

  • Missing a half-step: each leapfrog iteration has two half-step momentum updates, one at the beginning and one at the end. Skipping either breaks the time-reversal symmetry of the update and lowers its accuracy, so the chain is no longer guaranteed to leave the target invariant.
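  • Use lax.fori_loop, not a Python for-loop: a Python loop is unrolled at trace time, which bloats the traced program and fails outright when n_steps is a traced value (for example, a non-static argument under jax.jit). lax.fori_loop keeps the loop as a single loop construct in the compiled program.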
  • Cast n_steps to int32: lax.fori_loop(0, jnp.int32(n_steps), ...). Passing a float will raise a type error.
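  • Constant grad assumption: this implementation treats grad_log_prob as a scalar constant, which corresponds to a linear log-density (constant force). A Gaussian target has a quadratic potential and a gradient that changes with x, so real HMC recomputes the gradient at every momentum update, typically via jax.grad.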
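To illustrate the pitfalls above together, here is a minimal sketch of a leapfrog loop that recomputes the gradient with jax.grad inside lax.fori_loop. The target `log_prob` (a standard normal, up to a constant) and the name `leapfrog_with_grad` are assumptions made for this example; this is not the constant-gradient function the problem asks for.

```python
import jax
import jax.numpy as jnp
from jax import lax


def log_prob(x):
    # Hypothetical target: standard normal log-density, up to an additive constant.
    return -0.5 * x ** 2


grad_log_prob = jax.grad(log_prob)


def leapfrog_with_grad(x, p, step_size, n_steps):
    """Run n_steps leapfrog iterations, recomputing the gradient at each update."""

    def body(_, state):
        x, p = state
        p = p + 0.5 * step_size * grad_log_prob(x)  # half-step momentum at old x
        x = x + step_size * p                       # full-step position
        p = p + 0.5 * step_size * grad_log_prob(x)  # half-step momentum at new x
        return x, p

    x0 = jnp.asarray(x, dtype=jnp.float32)
    p0 = jnp.asarray(p, dtype=jnp.float32)
    # Integer loop bound, cast to int32 as the pitfall list recommends.
    return lax.fori_loop(0, jnp.int32(n_steps), body, (x0, p0))


x_new, p_new = leapfrog_with_grad(1.0, 0.0, step_size=0.1, n_steps=10)
```

For this target the energy 0.5 * x^2 + 0.5 * p^2 should stay close to its starting value, and running the same number of steps from (x_new, -p_new) should return approximately to the starting position, which is a quick way to check that both half-steps are present.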

Problem

Implement leapfrog_step(x, p, grad_log_prob, step_size, n_steps) that runs n_steps iterations of the leapfrog integrator, treating grad_log_prob as a constant gradient.

All inputs are scalars (floats). Return a 1-D float32 array of shape (2,): [final_x, final_p].

Hints

jax hmc leapfrog
