easy primitives

stop_gradient: Target Network

Why this matters

In deep Q-learning (DQN-style), the temporal-difference (TD) target is computed with a separate target network. That network is updated slowly, for example by copying the online weights every N steps. During a gradient step on the online network, gradients should not flow through the target. It acts as a fixed reference, not a trainable output.
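To make the online/target split concrete, here is a minimal sketch. Everything in it is an illustrative assumption:

- a hypothetical linear Q-function `q_values`,
- parameters stored as a dict pytree,
- plain SGD with learning rate 0.1,
- a hard copy of the online weights into the target every `sync_every` steps.

The names `td_loss_params`, `sync_target` and `train_step` are hypothetical helpers, not part of any library.

```python
import jax
import jax.numpy as jnp
from jax import lax

def q_values(params, obs):
    # Hypothetical linear Q-function: obs (obs_dim,) -> Q-values (n_actions,).
    return obs @ params["w"] + params["b"]

def td_loss_params(online, target, obs, action, reward, next_obs, gamma):
    q = q_values(online, obs)[action]
    # Bootstrap from the TARGET network and treat it as a constant for autodiff.
    bootstrap = lax.stop_gradient(jnp.max(q_values(target, next_obs)))
    return (reward + gamma * bootstrap - q) ** 2

def sync_target(step, online, target, sync_every):
    # Hard update: copy online -> target when step is a multiple of sync_every.
    # jnp.where keeps this traceable under jit (no Python branch on a tracer).
    do_sync = (step % sync_every) == 0
    return jax.tree_util.tree_map(
        lambda o, t: jnp.where(do_sync, o, t), online, target)

@jax.jit
def train_step(step, online, target, obs, action, reward, next_obs):
    # jax.grad differentiates argument 0 (the online params) only.
    grads = jax.grad(td_loss_params)(online, target, obs, action, reward,
                                     next_obs, 0.9)
    # Plain SGD with a hypothetical learning rate of 0.1.
    online = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, online, grads)
    target = sync_target(step, online, target, sync_every=100)
    return online, target
```

Here the target params never receive gradient, because `jax.grad` only differentiates argument 0. The `stop_gradient` is what still protects the bootstrap if the same parameters are ever passed for both networks.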

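lax.stop_gradient blocks autodiff through the wrapped value while leaving the forward pass completely unchanged. Without it, differentiating the TD loss would also send gradient into whatever produced the target. That could be the target network parameters, or the online parameters if the target is computed from them. The update would then chase a moving target, which tends to destabilize training.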
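A small scalar example shows both halves of that claim: the forward value is unchanged, but the derivative is not. It compares `x * stop_gradient(x)` with `x * x`. The functions `f` and `g` are illustrative names only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # Autodiff treats the right-hand factor as a constant.
    return x * lax.stop_gradient(x)

def g(x):
    # Same forward computation, no gradient blocking.
    return x * x

x = jnp.float32(3.0)
print(f(x), g(x))       # forward values agree: 9.0 9.0
print(jax.grad(f)(x))   # 3.0, i.e. d/dx of x * c with c = 3
print(jax.grad(g)(x))   # 6.0, i.e. d/dx of x ** 2
# Forward-mode differentiation is affected the same way.
print(jax.jvp(f, (x,), (jnp.float32(1.0),))[1])  # 3.0
```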

stop_gradient appears in:

  • DQN and its variants (Double DQN, Dueling DQN, Rainbow).
  • Actor-critic methods (DDPG, TD3, SAC target critics).
  • Distillation / knowledge transfer (teacher is a fixed reference).

Worked mini-example

```python
import jax
import jax.numpy as jnp
from jax import lax

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])
reward, gamma, action = 1.0, 0.9, 0

# WITHOUT stop_gradient: gradient flows into q_target (wrong)
target_bad = reward + gamma * jnp.max(q_target)
loss_bad = (target_bad - q_online[action]) ** 2
# under jax.grad, the gradient wrt q_target is non-zero; wrong, it is a fixed reference

# WITH stop_gradient: gradient is blocked (correct)
target = reward + gamma * lax.stop_gradient(jnp.max(q_target))
loss = (target - q_online[action]) ** 2
# under jax.grad, the gradient wrt q_target is 0; no gradient through the target
print(loss)  # scalar TD loss, same forward value as loss_bad
```
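The comments above describe gradients without computing them. This sketch checks them with `jax.grad` over both inputs. It wraps the two variants in hypothetical functions `loss_bad_fn` and `loss_good_fn` and reuses the same toy numbers.

```python
import jax
import jax.numpy as jnp
from jax import lax

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])
reward, gamma, action = 1.0, 0.9, 0

def loss_bad_fn(q_online, q_target):
    target = reward + gamma * jnp.max(q_target)
    return (target - q_online[action]) ** 2

def loss_good_fn(q_online, q_target):
    target = reward + gamma * lax.stop_gradient(jnp.max(q_target))
    return (target - q_online[action]) ** 2

g_on_bad, g_tgt_bad = jax.grad(loss_bad_fn, argnums=(0, 1))(q_online, q_target)
g_on_good, g_tgt_good = jax.grad(loss_good_fn, argnums=(0, 1))(q_online, q_target)

# TD error is 1 + 0.9 * 4 - 1 = 3.6.
print(g_tgt_bad)   # non-zero only at index 1 (argmax of q_target): 2 * 3.6 * 0.9 ≈ 6.48
print(g_tgt_good)  # all zeros: the bootstrap is a constant
print(g_on_bad, g_on_good)  # identical: -2 * 3.6 ≈ -7.2 at index 0 in both
```

The online-side gradient is the same in both versions. stop_gradient only removes the path into the target.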

Common pitfalls

  • stop_gradient only affects differentiation: the forward value is identical with or without it. The distinction shows up only under autodiff transforms such as jax.grad, jax.vjp, and jax.jvp.
  • It stops gradient through the WRAPPED value only. stop_gradient(jnp.max(q_target)) blocks gradient through that scalar; every other term in the expression still participates normally.
  • action_taken arrives as a float. This problem passes the action index as a float scalar (e.g. 0.0), and JAX arrays reject float indices. Cast first, e.g. q_online[jnp.int32(action_taken)], to avoid a type error.
  • Common mistake: wrapping q_online instead of q_target kills the training signal entirely, because the online network then receives zero gradient.
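The last two pitfalls can be seen together in a short sketch. It casts a float action index, then deliberately wraps the wrong side. `td_loss_wrong_wrap` is a hypothetical name for this broken variant.

```python
import jax
import jax.numpy as jnp
from jax import lax

def td_loss_wrong_wrap(q_online, q_target, action_taken, reward, gamma):
    a = jnp.int32(action_taken)  # action_taken arrives as a float; cast for indexing
    target = reward + gamma * jnp.max(q_target)
    # Pitfall: stop_gradient on the ONLINE value instead of the target bootstrap.
    return (target - lax.stop_gradient(q_online[a])) ** 2

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])
g_online, g_target = jax.grad(td_loss_wrong_wrap, argnums=(0, 1))(
    q_online, q_target, 0.0, 1.0, 0.9)
print(g_online)  # all zeros: the online network gets no learning signal
print(g_target)  # non-zero: gradient now flows into the "fixed" reference instead
```

Wrapping the wrong side fails twice over. The online network stops learning, and the target is no longer treated as a constant.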

Problem

Implement td_loss(q_online, q_target, action_taken, reward, gamma) that computes the squared one-step TD error, blocking gradients through the target bootstrap.

  • q_online: 1-D jax array (n_actions,).
  • q_target: 1-D jax array (n_actions,).
  • action_taken: scalar delivered as float; cast to int for indexing.
  • reward: scalar.
  • gamma: scalar discount factor.

Returns: scalar, (reward + gamma * stop_gradient(max(q_target)) - q_online[action])².
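Here is one way the specification above might be written, as a minimal sketch rather than the site's reference solution. It assumes the arguments arrive as described, with action_taken holding an integral value stored as a float. The cast uses `astype` so the function also works under `jax.jit`, where action_taken becomes a traced value.

```python
import jax
import jax.numpy as jnp
from jax import lax

def td_loss(q_online, q_target, action_taken, reward, gamma):
    # action_taken may arrive as a float (e.g. 0.0); JAX needs an integer index.
    a = jnp.asarray(action_taken).astype(jnp.int32)
    # Bootstrap from the target values, treated as a constant by autodiff.
    bootstrap = lax.stop_gradient(jnp.max(q_target))
    td_error = reward + gamma * bootstrap - q_online[a]
    return td_error ** 2

td_loss_jit = jax.jit(td_loss)  # also works when action_taken is traced
```

With the worked-example inputs (q_online = [1, 2], q_target = [3, 4], action 0.0, reward 1.0, gamma 0.9), the TD error is 3.6. The loss is therefore about 12.96, up to float32 rounding.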

Hints

jax stop-gradient rl
