REINFORCE Gradient Estimator
Why this matters
The REINFORCE estimator (Williams 1992) is the canonical score-function gradient estimator: it expresses the gradient of an expectation as another expectation, ∇θ E[R(a)] = E[R(a) ∇θ log p(a|θ)], enabling gradient-based optimisation even when the reward R is non-differentiable. REINFORCE is the foundation of policy-gradient methods in reinforcement learning (e.g. PPO, TRPO) and variational inference with non-reparameterisable posteriors.
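For intuition, the identity can be derived in a few lines. The sketch below assumes a finite action set and a reward R that does not itself depend on θ. Both assumptions are added here for illustration and are not stated in the text above.

```latex
\begin{aligned}
\nabla_\theta \, \mathbb{E}_{a \sim p(\cdot \mid \theta)}[R(a)]
  &= \nabla_\theta \sum_a p(a \mid \theta)\, R(a) \\
  &= \sum_a R(a)\, \nabla_\theta\, p(a \mid \theta) \\
  &= \sum_a R(a)\, p(a \mid \theta)\, \nabla_\theta \log p(a \mid \theta) \\
  &= \mathbb{E}_{a \sim p(\cdot \mid \theta)}\big[ R(a)\, \nabla_\theta \log p(a \mid \theta) \big]
\end{aligned}
```

The third line uses the log-derivative identity ∇θ p = p ∇θ log p. R is only ever evaluated and never differentiated, which is why it may be non-differentiable.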
Worked mini-example
K = 2 actions, logits = [0, 0], so p = [0.5, 0.5].
grad_logits log_softmax(logits)[a] = one_hot(a, 2) - softmax(logits).
For action 0: [1,0] - [0.5,0.5] = [0.5, -0.5].
For action 1: [0,1] - [0.5,0.5] = [-0.5, 0.5].
With rewards [10, 0], weight each action's term by its probability 0.5 (equivalently, take a batch of samples split evenly between the two actions):
E[R ∇ log p] = 10·[0.5,−0.5]·0.5 + 0·[−0.5,0.5]·0.5 = [2.5, −2.5].
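The arithmetic above can be checked numerically. This is a minimal JAX sketch. The names log_prob, scores and expected_grad are illustrative and do not come from the problem statement. It computes the exact expectation by summing over both actions rather than sampling.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([0.0, 0.0])    # K = 2, so p = [0.5, 0.5]
rewards = jnp.array([10.0, 0.0])

def log_prob(l, a):
    # log p(a) under a softmax policy
    return jax.nn.log_softmax(l)[a]

# One score vector per action, differentiating w.r.t. the logits only.
scores = jnp.stack([jax.grad(log_prob)(logits, a) for a in range(2)])

# Exact expectation: sum over actions of p(a) * R(a) * score(a).
probs = jax.nn.softmax(logits)
expected_grad = jnp.sum((probs * rewards)[:, None] * scores, axis=0)

print(scores)         # rows: [0.5, -0.5] and [-0.5, 0.5]
print(expected_grad)  # [2.5, -2.5]
```

The rows of scores match one_hot(a, 2) - softmax(logits), as in the hand calculation.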
Common pitfalls
- Integer actions break jax.grad: jax.random.categorical returns int32 actions, and jax.grad cannot differentiate with respect to an integer argument. Take the gradient w.r.t. logits instead: jax.grad(lambda l: jax.nn.log_softmax(l)[a])(logits).
- Raw logits vs log_softmax: log p(a) is log_softmax(logits)[a], not logits[a]. Using raw logits yields un-normalised scores and a wrong gradient.
- vmap over samples: compute one score-function vector per sampled action, then average. jax.vmap(score_fn)(actions) vectorises this.
- High variance: REINFORCE is unbiased but has high variance, which is why baselines (next problem) and reparameterisation (previous problem) are so important in practice.
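The first three pitfalls can be seen side by side in a short sketch. The example logits and actions are made up for illustration, and raw_logit_grad is a hypothetical helper that exists only to show the wrong variant. Here score_fn takes logits as an explicit argument, so the vmap call uses in_axes=(None, 0) rather than closing over logits.

```python
import jax
import jax.numpy as jnp

def score_fn(logits, action):
    # Correct: differentiate log_softmax w.r.t. logits; the integer action is only an index.
    return jax.grad(lambda l: jax.nn.log_softmax(l)[action])(logits)

def raw_logit_grad(logits, action):
    # Pitfall: differentiating the raw logit ignores the normaliser
    # and yields a plain one-hot vector.
    return jax.grad(lambda l: l[action])(logits)

logits = jnp.array([1.0, 0.0, -1.0])
actions = jnp.array([0, 2, 2, 1], dtype=jnp.int32)  # stand-in for sampled actions

# Map over the actions only; logits are shared across the batch.
scores = jax.vmap(score_fn, in_axes=(None, 0))(logits, actions)       # shape (4, 3)
wrong = jax.vmap(raw_logit_grad, in_axes=(None, 0))(logits, actions)  # shape (4, 3)

print(scores.sum(axis=1))  # each correct score vector sums to (about) zero
print(wrong.sum(axis=1))   # each one-hot vector sums to one
```

Each correct score vector equals one_hot(action) - softmax(logits), so its entries sum to zero. The raw-logit version drops the softmax term entirely.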
Problem
Implement reinforce_grad(seed, logits, reward_table, n_samples) that
estimates ∇_logits E[reward_table[action]] via REINFORCE.
- seed (float) → jax.random.PRNGKey(int(seed))
- logits: 1-D float32 array of length K (unnormalised log-probs)
- reward_table: 1-D float32 array of length K
- n_samples (float, cast to int): number of Monte Carlo samples
Return a 1-D float32 array of shape (K,) — the estimated gradient.
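One possible shape for such a function is sketched below. It is an illustration under the interface described above, not the site's reference solution. The function name and arguments follow the problem statement, while the inner helper score_fn and the choice to average with jnp.mean are assumptions.

```python
import jax
import jax.numpy as jnp

def reinforce_grad(seed, logits, reward_table, n_samples):
    key = jax.random.PRNGKey(int(seed))
    n = int(n_samples)

    # Sample n integer actions from softmax(logits).
    actions = jax.random.categorical(key, logits, shape=(n,))

    def score_fn(a):
        # Gradient of log p(a) w.r.t. the logits (not w.r.t. the integer action).
        return jax.grad(lambda l: jax.nn.log_softmax(l)[a])(logits)

    scores = jax.vmap(score_fn)(actions)   # shape (n, K)
    rewards = reward_table[actions]        # shape (n,)

    # Monte Carlo average of R(a) * grad log p(a).
    return jnp.mean(rewards[:, None] * scores, axis=0).astype(jnp.float32)

# Usage on the worked mini-example; the estimate should land near [2.5, -2.5].
g = reinforce_grad(0.0, jnp.zeros(2), jnp.array([10.0, 0.0]), 20000.0)
print(g)
```

Because the estimate is a sample average, it changes with seed and n_samples. Only its expectation equals the exact gradient.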