REINFORCE with Baseline
Why this matters
Vanilla REINFORCE has notoriously high variance: different random samples
produce wildly different gradient estimates. Control variates (baselines)
are the standard fix. Subtracting any constant b from rewards leaves the
expected gradient unchanged, because E[∇ log p(a)] = Σ_a ∇p(a) = ∇ Σ_a p(a) = ∇1 = 0
for any normalised distribution, but it can drastically reduce variance by centring
rewards near zero. In practice, actor-critic algorithms use a state-dependent
baseline b = V(s) (the value function). This identity underpins A3C, PPO, and most
modern policy-gradient methods.
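The identity E[∇ log p(a)] = 0 is easy to check numerically. The sketch below assumes a softmax (categorical) policy over K logits, as in the problem further down, computes each score ∇ log p(a) with autodiff, and takes the exact probability-weighted sum over all actions. The helper names `score` and `expected_score` are illustrative, not part of the problem.

```python
import jax
import jax.numpy as jnp

def score(logits, a):
    """grad_logits of log softmax(logits)[a], computed by autodiff."""
    return jax.grad(lambda z: jax.nn.log_softmax(z)[a])(logits)

def expected_score(logits):
    """Exact E_{a ~ softmax(logits)}[grad log p(a)], summed over all K actions."""
    probs = jax.nn.softmax(logits)
    actions = jnp.arange(logits.shape[0])
    # One score vector per action: shape (K, K).
    scores = jax.vmap(score, in_axes=(None, 0))(logits, actions)
    return probs @ scores

logits = jnp.array([1.0, -0.5, 2.0])
print(expected_score(logits))  # all entries ~0, up to float32 rounding
# Hence E[b * grad log p(a)] = b * 0 = 0 for any constant b.
```

For a softmax, each row equals onehot(a) − softmax(logits), so the weighted sum is p − p = 0 exactly; autodiff just confirms it without deriving the score by hand.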
Worked mini-example
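K = 2, logits = [0, 0] (so p = [0.5, 0.5]), rewards = [10, 0], baseline b = 5. Adjusted rewards: [10−5, 0−5] = [5, −5]. With the softmax score ∇ log p(a) = onehot(a) − p, E[grad] = 5·[0.5, −0.5]·0.5 + (−5)·[−0.5, 0.5]·0.5 = [2.5, −2.5]. Without the baseline: 10·[0.5, −0.5]·0.5 + 0·[−0.5, 0.5]·0.5 = [2.5, −2.5], the same expected gradient, so the baseline keeps the estimator unbiased. The variance, however, collapses: without the baseline a single sample yields either [5, −5] or [0, 0] (variance 6.25 per component), while with b = 5 both actions yield exactly [2.5, −2.5], so the variance is zero.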
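The mini-example can be checked by enumerating both actions instead of sampling. This sketch (the function name `exact_moments` is illustrative) computes the exact mean and per-component variance of the one-sample estimator g(a) = (R(a) − b)·∇ log p(a) for a softmax policy.

```python
import jax
import jax.numpy as jnp

def exact_moments(logits, reward_table, baseline):
    """Exact mean and per-component variance of g(a) = (R(a) - b) * grad log p(a)."""
    probs = jax.nn.softmax(logits)
    K = logits.shape[0]
    # Softmax score function: row a is onehot(a) - p.
    scores = jnp.eye(K) - probs[None, :]
    g = (reward_table - baseline)[:, None] * scores  # row a is g(a)
    mean = probs @ g
    var = probs @ (g - mean) ** 2
    return mean, var

logits = jnp.zeros(2)
rewards = jnp.array([10.0, 0.0])
print(exact_moments(logits, rewards, 0.0))  # mean [2.5, -2.5], var [6.25, 6.25]
print(exact_moments(logits, rewards, 5.0))  # mean [2.5, -2.5], var [0, 0]
```

Here b = 5 equals the mean reward, and because both actions have scores of equal norm, every sample lands on the same gradient; in general a baseline reduces the variance rather than eliminating it.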
Common pitfalls
- Baseline must NOT depend on the action: if b = b(a), then E[b(a) ∇ log p(a)] ≠ 0 in general, and the estimator becomes biased. A state-dependent baseline b = V(s) is fine because the state is fixed before the action is sampled.
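- Only one line changes from vanilla REINFORCE: subtract the baseline before weighting, i.e. rewards = reward_table[actions] - baseline.
- Optimal constant baseline: the constant that minimises the total variance of the estimator is b* = E[R‖∇ log p(a)‖²] / E[‖∇ log p(a)‖²], a gradient-weighted mean reward. The plain mean reward E[R] is a simple approximation (exact when R and ‖∇ log p(a)‖² are uncorrelated), which is why the sample mean is a common choice.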
- Compare to test 3 (baseline=0): setting b = 0 should reproduce vanilla REINFORCE exactly for the same seed.
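Two of the pitfalls can be made concrete with exact expectations over a small softmax policy. The sketch below (all function names and the example logits/rewards are illustrative) shows that a constant baseline leaves the expected gradient unchanged, that the action-dependent choice b(a) = R(a) drives it to zero (a clear bias), and that the variance-minimising constant b* = E[R‖∇ log p‖²] / E[‖∇ log p‖²] does no worse than the mean reward E[R].

```python
import jax
import jax.numpy as jnp

def softmax_scores(logits):
    """Return p and a (K, K) matrix whose row a is grad log p(a) = onehot(a) - p."""
    probs = jax.nn.softmax(logits)
    return probs, jnp.eye(logits.shape[0]) - probs[None, :]

def exact_mean_grad(logits, reward_table, b_per_action):
    """Exact E_a[(R(a) - b(a)) * grad log p(a)], enumerating every action."""
    probs, scores = softmax_scores(logits)
    return probs @ ((reward_table - b_per_action)[:, None] * scores)

def total_variance(logits, reward_table, b):
    """Sum over components of Var[(R(a) - b) * grad log p(a)] for a constant b."""
    probs, scores = softmax_scores(logits)
    g = (reward_table - b)[:, None] * scores
    mean = probs @ g
    return jnp.sum(probs @ (g - mean) ** 2)

def optimal_constant_baseline(logits, reward_table):
    """b* = E[R * ||grad log p||^2] / E[||grad log p||^2]."""
    probs, scores = softmax_scores(logits)
    w = jnp.sum(scores ** 2, axis=1)  # squared score norm per action
    return jnp.sum(probs * w * reward_table) / jnp.sum(probs * w)

logits = jnp.array([2.0, 0.0, -1.0])
rewards = jnp.array([1.0, 3.0, 0.0])
true_grad = exact_mean_grad(logits, rewards, jnp.zeros(3))

# Any constant baseline leaves the expected gradient unchanged.
with_const = exact_mean_grad(logits, rewards, jnp.full(3, 1.7))

# Action-dependent baseline b(a) = R(a): every weight becomes zero, so the
# expected "gradient" is 0 even though true_grad is not. This is the bias.
with_action_dep = exact_mean_grad(logits, rewards, rewards)

b_star = optimal_constant_baseline(logits, rewards)
b_mean = jnp.sum(jax.nn.softmax(logits) * rewards)  # E[R]
print(b_star, b_mean)
print(total_variance(logits, rewards, b_star), total_variance(logits, rewards, b_mean))
```

Enumerating actions is only feasible for small K, but it removes sampling noise, which makes bias and variance differences visible directly.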
Problem
Implement reinforce_with_baseline(seed, logits, reward_table, baseline, n_samples): identical to vanilla REINFORCE but with (reward − baseline) as the weight.
- seed (float) → jax.random.PRNGKey(int(seed))
- logits: 1-D float32 array of length K
- reward_table: 1-D float32 array of length K
- baseline: scalar float subtracted from every reward
- n_samples (float, cast to int): number of MC samples
Return a 1-D float32 array of shape (K,).
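A minimal sketch of the function in JAX, under stated assumptions. The problem text does not spell out vanilla REINFORCE, so this sketch assumes it draws n_samples actions with jax.random.categorical and averages (R(a) − b)·∇ log softmax(logits)[a] over the samples; a reference solution that samples differently or reduces differently (e.g. sums instead of averages) would produce different numbers.

```python
import jax
import jax.numpy as jnp

def reinforce_with_baseline(seed, logits, reward_table, baseline, n_samples):
    """Monte Carlo REINFORCE gradient w.r.t. logits, with a constant baseline.

    Assumption (not confirmed by the problem): actions are sampled with
    jax.random.categorical and the per-sample estimates are averaged.
    """
    key = jax.random.PRNGKey(int(seed))
    logits = jnp.asarray(logits, dtype=jnp.float32)
    reward_table = jnp.asarray(reward_table, dtype=jnp.float32)
    n = int(n_samples)

    # Sample n actions from the softmax policy defined by the logits.
    actions = jax.random.categorical(key, logits, shape=(n,))
    # The only change from vanilla REINFORCE: centre rewards with the baseline.
    weights = reward_table[actions] - baseline

    def surrogate(z):
        # weights do not depend on z, so grad of this mean is the
        # average of weights[i] * grad log p(actions[i]).
        logp = jax.nn.log_softmax(z)[actions]
        return jnp.mean(weights * logp)

    return jax.grad(surrogate)(logits).astype(jnp.float32)

g = reinforce_with_baseline(0.0, jnp.zeros(2), jnp.array([10.0, 0.0]), 5.0, 1000.0)
print(g)  # the worked example's zero-variance case: [2.5, -2.5]
```

Writing the estimator as the gradient of a surrogate lets autodiff supply ∇ log p(a); with baseline = 0.0 and the same seed, the same function gives the vanilla estimate.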