
REINFORCE Gradient Estimator

Why this matters

The REINFORCE estimator (Williams 1992) is the canonical score-function gradient estimator. It rewrites the gradient of an expectation as another expectation, ∇θ E_{a∼p(·|θ)}[R(a)] = E_{a∼p(·|θ)}[R(a) ∇θ log p(a|θ)]. Only log p has to be differentiated, so gradient-based optimisation works even when the reward R is non-differentiable. REINFORCE is the foundation of policy-gradient methods in reinforcement learning (e.g. PPO, TRPO) and of variational inference with non-reparameterisable posteriors.
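For a discrete action space, as in the problem below, the identity follows from the log-derivative trick ∇θ p = p ∇θ log p. Here is a short derivation, assuming R(a) does not depend on θ. That assumption is why R never needs to be differentiated:

```latex
\nabla_\theta \mathbb{E}_{a \sim p(\cdot\mid\theta)}[R(a)]
  = \nabla_\theta \sum_a p(a\mid\theta)\, R(a)
  = \sum_a R(a)\, \nabla_\theta p(a\mid\theta)
  = \sum_a p(a\mid\theta)\, R(a)\, \nabla_\theta \log p(a\mid\theta)
  = \mathbb{E}_{a \sim p(\cdot\mid\theta)}\!\left[R(a)\, \nabla_\theta \log p(a\mid\theta)\right]
```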

Worked mini-example

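K = 2 actions, logits = [0, 0], so p = softmax(logits) = [0.5, 0.5]. The score function is ∇_logits log_softmax(logits)[a] = one_hot(a, 2) − softmax(logits). For action 0: [1, 0] − [0.5, 0.5] = [0.5, −0.5]. For action 1: [0, 1] − [0.5, 0.5] = [−0.5, 0.5]. With rewards [10, 0], weighting each term by its probability 0.5 gives E[R ∇ log p] = 10·[0.5, −0.5]·0.5 + 0·[−0.5, 0.5]·0.5 = [2.5, −2.5]. A Monte Carlo estimate with roughly half of its samples on each action lands close to this value. The value also matches the analytic gradient p_j (R_j − E[R]): 0.5·(10 − 5) = 2.5 and 0.5·(0 − 5) = −2.5.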
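The hand calculation above can be checked in JAX. The sketch below computes the exact expectation by summing over both actions instead of sampling, and compares it with jax.grad applied directly to E[R] = Σ_a p(a) R(a). The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([0.0, 0.0], dtype=jnp.float32)
rewards = jnp.array([10.0, 0.0], dtype=jnp.float32)
K = logits.shape[0]

def score(a):
    # Score function: grad of log p(a) w.r.t. logits = one_hot(a) - softmax(logits)
    return jax.grad(lambda l: jax.nn.log_softmax(l)[a])(logits)

# Exact expectation: weight every action's R(a) * score(a) by p(a) and sum
probs = jax.nn.softmax(logits)
scores = jax.vmap(score)(jnp.arange(K))  # shape (K, K), one row per action
reinforce_exact = jnp.sum(probs[:, None] * rewards[:, None] * scores, axis=0)

# Direct gradient of E[R] = sum_a p(a) R(a), for comparison
direct = jax.grad(lambda l: jnp.dot(jax.nn.softmax(l), rewards))(logits)

print(reinforce_exact)  # both should be approximately [2.5, -2.5]
print(direct)
```

Both vectors should agree with the hand-computed [2.5, −2.5]. This confirms that the score-function form and the direct gradient describe the same quantity.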

Common pitfalls

  • Integer actions break jax.grad: jax.random.categorical returns integer indices (int32 by default). jax.grad cannot differentiate with respect to an integer argument and raises a TypeError. Take the gradient w.r.t. the float logits instead, and close over the sampled action as an index: jax.grad(lambda l: jax.nn.log_softmax(l)[a])(logits).
  • Raw logits vs log_softmax: log p(a) is log_softmax(logits)[a], not logits[a]. The gradient of logits[a] is just one_hot(a), which drops the −softmax(logits) normalisation term and gives a wrong gradient.
  • vmap over samples: compute one reward-weighted score vector R(a)·∇ log p(a) per sampled action, then average over samples. jax.vmap(score_fn)(actions) vectorises the per-sample gradients.
  • High variance: REINFORCE is unbiased but has high variance — that is why baselines (next problem) and reparameterisation (previous problem) are so important in practice.
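The pitfalls above combine into a minimal Monte Carlo sketch. The assumptions are: actions are drawn with jax.random.categorical; the score is taken w.r.t. the float logits through log_softmax; per-sample terms are vectorised with jax.vmap. The helper name reinforce_estimate is hypothetical. It is an illustration of the plain estimator (no baseline), not the reference solution to the problem below.

```python
import jax
import jax.numpy as jnp

def reinforce_estimate(key, logits, rewards, n_samples):
    """Hypothetical helper: Monte Carlo REINFORCE estimate of
    grad_logits E[rewards[a]] with a ~ softmax(logits). No baseline."""
    # Sampled actions are integers: used only as indices, never differentiated.
    actions = jax.random.categorical(key, logits, shape=(n_samples,))

    def score_fn(a):
        # Differentiate the normalised log-prob w.r.t. the float logits.
        return jax.grad(lambda l: jax.nn.log_softmax(l)[a])(logits)

    scores = jax.vmap(score_fn)(actions)           # (n_samples, K)
    weighted = rewards[actions][:, None] * scores  # R(a) * grad log p(a)
    return jnp.mean(weighted, axis=0)              # average over samples

key = jax.random.PRNGKey(0)
logits = jnp.array([0.0, 0.0], dtype=jnp.float32)
rewards = jnp.array([10.0, 0.0], dtype=jnp.float32)
est = reinforce_estimate(key, logits, rewards, 20000)
print(est)  # close to [2.5, -2.5]; exact digits depend on the key
```

On the worked example each single-sample term is either [5, −5] or [0, 0]. The estimate is therefore unbiased but noisy, and its standard error shrinks only as 1/√n_samples.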

Problem

Implement reinforce_grad(seed, logits, reward_table, n_samples) that estimates ∇_logits E[reward_table[action]], with action ∼ Categorical(softmax(logits)), via REINFORCE.

  • seed (float) → jax.random.PRNGKey(int(seed))
  • logits — 1-D float32 array of length K (unnormalised log-probs)
  • reward_table — 1-D float32 array of length K
  • n_samples (float, cast to int) — number of Monte Carlo samples

Return a 1-D float32 array of shape (K,) — the estimated gradient.

Hints

jax reinforce score-function
