
Metropolis-Hastings Step

Why this matters

The Metropolis-Hastings (MH) algorithm is the building block of most MCMC samplers. At each step it proposes a new state x′ from a proposal density q(x′ | x), computes the acceptance probability α = min(1, p(x′)/p(x) · q(x | x′)/q(x′ | x)), draws u ~ Uniform(0, 1), and moves to x′ if u < α, otherwise stays at x. For symmetric proposals, such as a Gaussian centred on the current position, q(x | x′) = q(x′ | x), so the proposal ratio cancels and α = min(1, p(x′)/p(x)). Working in log-space is essential for numerical stability: in high-dimensional models the densities p(x) themselves routinely underflow to zero in floating point, while their logarithms stay well-scaled, and log α = min(0, log p(x′) - log p(x)) needs only a subtraction.
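To make the Hastings correction concrete, here is a minimal sketch that computes log α with the q terms included and checks that they cancel for a Gaussian random-walk proposal. The helper `log_accept_ratio`, the standard-normal target and the specific numbers are hypothetical choices for illustration, not part of the problem's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_accept_ratio(log_p_new, log_p_old, log_q_old_given_new, log_q_new_given_old):
    """Hypothetical helper: log alpha = min(0, [log p(x') + log q(x|x')] - [log p(x) + log q(x'|x)])."""
    return jnp.minimum(
        0.0,
        (log_p_new + log_q_old_given_new) - (log_p_old + log_q_new_given_old),
    )


# Hypothetical target for illustration: a standard normal, log p(x) = norm.logpdf(x).
x, x_new, step = 0.3, 1.1, 0.5

# Gaussian random-walk proposal q(x'|x) = N(x'; x, step^2) is symmetric,
# so q(x|x') == q(x'|x) and the Hastings correction cancels.
log_q_fwd = norm.logpdf(x_new, loc=x, scale=step)  # log q(x'|x)
log_q_bwd = norm.logpdf(x, loc=x_new, scale=step)  # log q(x|x')

with_q = log_accept_ratio(norm.logpdf(x_new), norm.logpdf(x), log_q_bwd, log_q_fwd)
without_q = jnp.minimum(0.0, norm.logpdf(x_new) - norm.logpdf(x))
print(with_q, without_q)  # agree up to floating-point rounding
```

Here log α = -(1.1² - 0.3²)/2 = -0.56 either way. For an asymmetric proposal (for example one that drifts towards a mode) the two q terms no longer match and must be kept.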

Worked mini-example

import jax, jax.numpy as jnp

seed = 0.0; x = 0.0; log_p_current = -1.0
proposal_x = 1.0; log_p_proposal = -0.5

key = jax.random.PRNGKey(int(seed))    # seed arrives as a float; PRNGKey needs an int
log_alpha = jnp.minimum(0.0, log_p_proposal - log_p_current)
# log_alpha = min(0, -0.5 - (-1.0)) = min(0, 0.5) = 0.0
u = jax.random.uniform(key)            # some u in [0, 1)
accepted = jnp.log(u) < log_alpha      # log(u) < 0 for every u in [0, 1) → True
new_x = jnp.where(accepted, proposal_x, x)  # 1.0
# Pack [new_x, accepted_flag] as a float32 array of shape (2,)
result = jnp.stack([new_x.astype(jnp.float32), accepted.astype(jnp.float32)])  # [1.0, 1.0]
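The example above performs one step; a sampler repeats it, feeding each output state into the next step. A minimal sketch, assuming a standard-normal target and a Gaussian random-walk proposal; `log_target`, `rw_mh_chain` and `step_size` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def log_target(x):
    # Hypothetical target for illustration: unnormalised standard normal.
    return -0.5 * x**2


def rw_mh_chain(key, x0, n_steps, step_size=1.0):
    """Random-walk Metropolis chain; returns samples and per-step acceptance flags."""

    def one_step(carry, key):
        x, log_p_x = carry
        key_prop, key_u = jax.random.split(key)
        # Symmetric Gaussian proposal centred on the current state, so q cancels.
        x_prop = x + step_size * jax.random.normal(key_prop)
        log_p_prop = log_target(x_prop)
        log_alpha = jnp.minimum(0.0, log_p_prop - log_p_x)
        accepted = jnp.log(jax.random.uniform(key_u)) < log_alpha
        # Move or stay; carry the cached log density so it is not recomputed.
        x_new = jnp.where(accepted, x_prop, x)
        log_p_new = jnp.where(accepted, log_p_prop, log_p_x)
        return (x_new, log_p_new), (x_new, accepted.astype(jnp.float32))

    # One fresh key per step: reusing a single key would repeat the same u forever.
    keys = jax.random.split(key, n_steps)
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    _, (samples, flags) = jax.lax.scan(one_step, (x0, log_target(x0)), keys)
    return samples, flags


samples, flags = rw_mh_chain(jax.random.PRNGKey(0), 0.0, 20_000, step_size=2.4)
print("acceptance rate:", flags.mean())
print("mean, var:", samples.mean(), samples.var())  # roughly 0 and 1
```

`jax.lax.scan` keeps the whole loop inside one traced computation instead of a Python loop. With step_size = 2.4 on a 1-D standard normal the acceptance rate should come out somewhere around 0.44; smaller steps accept more often but explore more slowly.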

Common pitfalls

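  • Compare log(u) < log_alpha rather than u < exp(log_alpha). After the clamp both tests agree (when log_p_proposal > log_p_current, log_alpha = 0 and exp(log_alpha) = 1, so u < 1 always holds), but the exp is an unnecessary extra step, and exponentiating anything before clamping is where things break: exp(log_p_proposal - log_p_current) can overflow to inf, and exponentiating the raw log densities can underflow both to 0 and give 0/0 = NaN.
  • Always accept when the proposal is better: log_alpha = min(0, positive) = 0, and log(u) < 0 for any u ∈ (0, 1), so the proposal is always accepted. (jax.random.uniform samples from [0, 1); u = 0 gives log(u) = -inf, which is still < 0.)
  • Return format: pack [new_x, accepted_flag] as a 1-D float32 array of shape (2,), casting the boolean acceptance result to 1.0 or 0.0.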
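The numerical points in these pitfalls can be checked directly. A minimal sketch, assuming float32 (JAX's default) and hypothetical log-density values of the magnitude that can appear in high-dimensional models; the fixed u = 0.7 stands in for a random draw.

```python
import jax.numpy as jnp

# Hypothetical float32 log densities, far below 0 as in high-dimensional models.
log_p_current = jnp.float32(-1000.0)
log_p_proposal = jnp.float32(-900.0)

# Naive probability-space ratio: both densities underflow to 0 in float32.
p_ratio = jnp.exp(log_p_proposal) / jnp.exp(log_p_current)
print(p_ratio)  # nan (0 / 0)

# Exponentiating the unclamped log ratio overflows instead.
print(jnp.exp(log_p_proposal - log_p_current))  # inf (exp(100) exceeds float32 max)

# Log-space comparison: clamp, then compare against log(u); no exp needed.
log_alpha = jnp.minimum(0.0, log_p_proposal - log_p_current)  # 0.0
u = jnp.float32(0.7)                     # stand-in for jax.random.uniform(key)
accepted = jnp.log(u) < log_alpha        # True: the proposal is better
flag = accepted.astype(jnp.float32)      # 1.0, ready to pack into the output
```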

Problem

Implement mh_step(seed, x, log_p_current, proposal_x, log_p_proposal) for a single Metropolis-Hastings step with symmetric proposal.

seed is a float (cast to int for PRNGKey). All other inputs are scalars. Return a 1-D float32 array of shape (2,): [new_x, accepted_flag], where accepted_flag is 1.0 if the proposal was accepted and 0.0 otherwise.

