hard primitives

Importance Sampling

Why this matters

Importance sampling is the foundational technique for computing expectations under a distribution p when we can only draw samples from a proposal q. The identity E_p[h(z)] = E_q[h(z) * p(z)/q(z)] lets us reweight samples from q to approximate expectations under p. This is the backbone of variational inference, off-policy reinforcement learning, sequential Monte Carlo, and annealed importance sampling.
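The identity follows from multiplying and dividing by q(z) inside the integral. The sketch below assumes q(z) > 0 wherever h(z) p(z) is nonzero, a support condition the text leaves implicit; the last line is the finite-sample estimator written in the log-density form used later in the problem statement.

```latex
\begin{align*}
\mathbb{E}_p[h(z)]
  &= \int h(z)\, p(z)\, dz
   = \int h(z)\, \frac{p(z)}{q(z)}\, q(z)\, dz
   = \mathbb{E}_q\!\left[ h(z)\, \frac{p(z)}{q(z)} \right] \\
  &\approx \frac{1}{N} \sum_{i=1}^{N} h(z_i)\,
     \exp\bigl(\log p(z_i) - \log q(z_i)\bigr),
  \qquad z_i \sim q .
\end{align*}
```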

Worked mini-example

N = 4, log_p = log_q (so all weights = 1), h = [1, 2, 3, 4].

log_weights = log_p - log_q = [0, 0, 0, 0]
weights = exp(log_weights) = [1, 1, 1, 1]
estimate = mean(h * weights) = mean([1, 2, 3, 4]) = 2.5

When p = q the estimator reduces to the plain Monte Carlo average. ✓
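The three steps of the mini-example can be reproduced directly in JAX. This is a minimal sketch: the zero-valued log-densities are placeholders chosen only so that log_p equals log_q, and the variable names are illustrative.

```python
import jax.numpy as jnp

# Worked mini-example: p = q, so every log-weight is zero.
# The actual log-density values are placeholders; only their equality matters.
log_p = jnp.zeros(4, dtype=jnp.float32)
log_q = jnp.zeros(4, dtype=jnp.float32)
h = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)

log_weights = log_p - log_q        # all zeros
weights = jnp.exp(log_weights)     # all ones
estimate = jnp.mean(h * weights)   # importance-sampling estimate

plain_mc = jnp.mean(h)             # ordinary Monte Carlo average, for comparison
```

With unit weights, estimate and plain_mc coincide at 2.5.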

Common pitfalls

  • Forgetting the (1/N) normalisation: the self-normalised estimator sum(h * w) / sum(w) is also valid, but this problem uses the unnormalised form mean(h * w). The unnormalised form assumes p and q are both proper densities (each integrates to 1); if either is known only up to a constant, the two forms no longer agree, even for large N.
  • Exponentiating too early: compute log_weights = log_p - log_q first and call exp once. Forming the raw ratio p/q from already-exponentiated values underflows or overflows when the probabilities are tiny.
  • High variance when q mismatches p: if q has thin tails relative to p, a few samples get enormous weights and dominate the estimate. Test 3 shows the weights at work in a uniform case: log_p - log_q = 1 everywhere, so every sample gets weight e ≈ 2.718 and the estimate is e times the plain mean of h.
  • No sampling required here: the inputs are pre-evaluated arrays log_p_at, log_q_at, h_at, and the function only computes the weighted average. This keeps the test contract deterministic.
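The pitfalls above can be seen on a small hand-made example. The sketch below uses hypothetical inputs (very negative log-densities, chosen to force float32 underflow) and a hypothetical helper, effective_sample_size, which computes the commonly used diagnostic (sum w)^2 / sum(w^2). That diagnostic is not part of this problem; it is included only as one way to notice when a few weights dominate.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical inputs: log-densities so negative that exp() underflows in float32.
log_p = jnp.array([-200.0, -201.0, -202.0, -203.0], dtype=jnp.float32)
log_q = jnp.array([-200.5, -200.5, -202.5, -202.5], dtype=jnp.float32)
h = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)

# Pitfall: exponentiating first gives 0 / 0, i.e. nan weights.
naive_weights = jnp.exp(log_p) / jnp.exp(log_q)

# Safer: subtract in log-space, then exponentiate once.
log_w = log_p - log_q              # [0.5, -0.5, 0.5, -0.5]
w = jnp.exp(log_w)

# The two estimator forms give different numbers on the same inputs.
unnormalised = jnp.mean(h * w)                  # form used by this problem
self_normalised = jnp.sum(h * w) / jnp.sum(w)   # alternative form


def effective_sample_size(log_weights):
    """Hypothetical helper: (sum w)^2 / sum(w^2), computed in log-space."""
    return jnp.exp(2.0 * logsumexp(log_weights) - logsumexp(2.0 * log_weights))


# Equal weights: effective sample size is close to N = 4.
ess_balanced = effective_sample_size(jnp.zeros(4, dtype=jnp.float32))
# One dominant weight: effective sample size collapses to close to 1.
ess_skewed = effective_sample_size(jnp.array([10.0, 0.0, 0.0, 0.0], dtype=jnp.float32))
```

A low effective sample size relative to N is a hint that the estimate rests on only a handful of samples, which is the high-variance situation described in the third bullet.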

Problem

Implement importance_estimator(log_p_at, log_q_at, h_at):

  • log_p_at: 1-D float32 array, shape (N,), log p evaluated at the sample points
  • log_q_at: 1-D float32 array, shape (N,), log q evaluated at the sample points
  • h_at: 1-D float32 array, shape (N,), h evaluated at the sample points

Return a scalar: (1/N) * sum(h(z_i) * exp(log_p(z_i) - log_q(z_i))).
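One possible sketch of the function, following the return formula literally, is shown here. It is not the reference solution; the use of jax.jit is optional, and the check at the bottom uses made-up inputs that mirror the constant-offset case described for Test 3 (log_p - log_q = 1 everywhere).

```python
import jax
import jax.numpy as jnp


@jax.jit
def importance_estimator(log_p_at, log_q_at, h_at):
    """Unnormalised importance-sampling estimate: mean(h * exp(log_p - log_q))."""
    log_weights = log_p_at - log_q_at   # subtract in log-space first
    weights = jnp.exp(log_weights)      # exponentiate once
    return jnp.mean(h_at * weights)     # (1/N) * sum(h_i * w_i)


# Made-up inputs with a constant log-weight of 1, so every weight equals e.
h_at = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
log_q_at = jnp.full((4,), -1.0, dtype=jnp.float32)
log_p_at = log_q_at + 1.0

estimate = importance_estimator(log_p_at, log_q_at, h_at)  # roughly e * 2.5
```

Because every weight is the same constant, the result is just that constant times the plain mean of h_at, which makes this a convenient sanity check.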

Hints

jax importance-sampling weights
