
REINFORCE with Baseline

Why this matters

Vanilla REINFORCE has notoriously high variance: different random samples produce wildly different gradient estimates. Control variates (baselines) are the standard fix. Subtracting any constant b from the rewards leaves the expected gradient unchanged, because E[∇ log p(a)] = 0 for any normalised distribution (Σ_a p(a) ∇ log p(a) = Σ_a ∇p(a) = ∇ Σ_a p(a) = ∇1 = 0). It can, however, drastically reduce variance by centring the rewards near zero. In practice, actor-critic algorithms use a learned state-value estimate b = V(s) as the baseline. This identity underpins PPO, A3C, and essentially every modern policy-gradient method.
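The zero-mean score identity can be checked numerically. The sketch below assumes a softmax policy over K discrete actions (the setting of this problem) and enumerates every action instead of sampling, so the result is exact up to float32 rounding. `expected_score` is a hypothetical helper name, not part of the problem's API.

```python
import jax
import jax.numpy as jnp

def expected_score(logits):
    """Exact E_{a ~ softmax(logits)}[grad_logits log p(a)], enumerating all K actions."""
    probs = jax.nn.softmax(logits)
    # Row a of this Jacobian is grad_logits log p(a), which equals onehot(a) - probs.
    scores = jax.jacobian(jax.nn.log_softmax)(logits)  # shape (K, K)
    # Weight each action's score by its probability and sum over actions.
    return probs @ scores

logits = jnp.array([1.0, -0.5, 2.0, 0.3])
print(expected_score(logits))  # ≈ [0, 0, 0, 0] up to float32 rounding
```

Since E[(R − b) ∇ log p(a)] = E[R ∇ log p(a)] − b · E[∇ log p(a)], a zero vector here is exactly why a constant b cannot shift the expected gradient.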

Worked mini-example

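K = 2, logits = [0, 0], rewards = [10, 0], baseline b = 5. Adjusted rewards: [10 − 5, 0 − 5] = [5, −5]. With p = [0.5, 0.5], the score ∇ log p(a) is [0.5, −0.5] for a = 0 and [−0.5, 0.5] for a = 1, so E[grad] = 5·[0.5, −0.5]·0.5 + (−5)·[−0.5, 0.5]·0.5 = [2.5, −2.5]. Without a baseline: 10·[0.5, −0.5]·0.5 + 0·[−0.5, 0.5]·0.5 = [2.5, −2.5], the same expected gradient ✓ (unbiasedness). The variance, however, collapses: with the baseline both actions yield exactly [2.5, −2.5], so the single-sample variance drops from 6.25 per component (samples are either [5, −5] or [0, 0]) to zero.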
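The worked example can be reproduced exactly by enumerating both actions. This sketch (helper names hypothetical) computes the mean and the per-component variance of the single-sample estimator for b = 0 and b = 5.

```python
import jax
import jax.numpy as jnp

def per_action_grads(logits, reward_table, baseline):
    """Row a holds the single-sample estimate (R(a) - b) * grad_logits log p(a)."""
    scores = jax.jacobian(jax.nn.log_softmax)(logits)  # row a: onehot(a) - probs
    return (reward_table - baseline)[:, None] * scores  # shape (K, K)

def mean_and_variance(logits, reward_table, baseline):
    """Exact per-component mean and variance of the single-sample estimator."""
    probs = jax.nn.softmax(logits)
    g = per_action_grads(logits, reward_table, baseline)
    mean = probs @ g                  # sum_a p(a) * g[a]
    var = probs @ (g - mean) ** 2     # sum_a p(a) * (g[a] - mean)^2, per component
    return mean, var

logits = jnp.array([0.0, 0.0])
rewards = jnp.array([10.0, 0.0])
for b in (0.0, 5.0):
    mean, var = mean_and_variance(logits, rewards, b)
    print(b, mean, var)
# b = 0.0: mean ≈ [2.5, -2.5], var ≈ [6.25, 6.25]
# b = 5.0: mean ≈ [2.5, -2.5], var ≈ [0, 0]
```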

Common pitfalls

  • Baseline must NOT depend on the action: if b = b(a), then E[b(a) ∇ log p(a)] ≠ 0 in general, and the estimator becomes biased. A state-dependent baseline b = V(s) is fine because the state is fixed before the action is sampled.
  • Only one line changes from vanilla REINFORCE: subtract the baseline before weighting, i.e. rewards = reward_table[actions] - baseline.
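  • Optimal constant baseline: the exact variance-minimising constant is the gradient-weighted mean b* = E[R·‖∇ log p(a)‖²] / E[‖∇ log p(a)‖²]. The plain mean reward E[R] is a close, cheap approximation (and equals b* whenever ‖∇ log p(a)‖² is the same for every action, as in the worked example), which is why the sample mean is a common choice.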
  • Compare to test 3 (baseline=0): setting b = 0 should reproduce vanilla REINFORCE exactly.
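The first pitfall is easy to see by exact enumeration. In this sketch (helper name hypothetical), a constant baseline reproduces the true gradient of E[R], while a deliberately bad action-dependent baseline, b(a) = R(a), zeroes every weight and therefore returns a zero gradient, which is biased. The true gradient is computed independently with `jax.grad` on the exact expected reward softmax(logits) · reward_table.

```python
import jax
import jax.numpy as jnp

def exact_expected_estimate(logits, reward_table, baseline_per_action):
    """E_a[(R(a) - b(a)) * grad_logits log p(a)], enumerating all actions.
    A constant baseline is the special case where all entries of baseline_per_action are equal."""
    probs = jax.nn.softmax(logits)
    scores = jax.jacobian(jax.nn.log_softmax)(logits)  # row a: onehot(a) - probs
    weights = reward_table - baseline_per_action
    return probs @ (weights[:, None] * scores)

logits = jnp.array([0.3, -0.2, 1.0])
rewards = jnp.array([1.0, 4.0, 2.0])

# Ground truth: differentiate the exact expected reward directly.
true_grad = jax.grad(lambda l: jax.nn.softmax(l) @ rewards)(logits)

# Constant baseline: unbiased, matches true_grad.
constant_b = exact_expected_estimate(logits, rewards, jnp.full(3, 2.5))

# Hypothetical bad choice b(a) = R(a): every weight is 0, so the estimate is biased to 0.
action_b = exact_expected_estimate(logits, rewards, rewards)
```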

Problem

Implement reinforce_with_baseline(seed, logits, reward_table, baseline, n_samples): identical to vanilla REINFORCE, but with (reward − baseline) as the per-sample weight.

  • seed (float) → jax.random.PRNGKey(int(seed))
  • logits: 1-D float32 array of length K
  • reward_table: 1-D float32 array of length K
  • baseline: scalar float subtracted from every reward
  • n_samples (float, cast to int): number of MC samples

Return a 1-D float32 array of shape (K,).
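A minimal sketch of one possible solution follows. It assumes the estimator is the sample mean over n_samples of (R(a) − baseline) · ∇_logits log p(a), with all actions drawn by `jax.random.categorical` from a single key. The reference grader may draw samples differently (for example by splitting the key), so values for a given seed may not match it exactly. The score is written in closed form, onehot(a) − softmax(logits), rather than obtained through `jax.grad`.

```python
import jax
import jax.numpy as jnp

def reinforce_with_baseline(seed, logits, reward_table, baseline, n_samples):
    """Monte Carlo estimate of grad_logits E_{a ~ softmax(logits)}[R(a)],
    weighting each sample by (R(a) - baseline)."""
    key = jax.random.PRNGKey(int(seed))
    n = int(n_samples)
    # Assumption: all n actions come from one categorical draw with a single key.
    actions = jax.random.categorical(key, logits, shape=(n,))
    # The only change from vanilla REINFORCE: centre the rewards by the baseline.
    rewards = reward_table[actions] - baseline
    # Closed-form score for a softmax policy: onehot(a) - softmax(logits).
    probs = jax.nn.softmax(logits)
    scores = jax.nn.one_hot(actions, logits.shape[0]) - probs  # shape (n, K)
    # Average the weighted scores over the samples.
    grad = jnp.mean(rewards[:, None] * scores, axis=0)
    return grad.astype(jnp.float32)
```

With the worked example's inputs (logits [0, 0], rewards [10, 0], baseline 5), every sample contributes exactly [2.5, −2.5], so the estimate is [2.5, −2.5] for any seed. With baseline 0 the estimate only fluctuates around that value.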

