stop_gradient: Target Network
Why this matters
In deep Q-learning (DQN), the TD (temporal difference) target is computed with a separate target network that is updated slowly (e.g., copied from the online network every N steps). During a gradient step on the online network, gradients should not flow through the target. It acts as a fixed reference, not a trainable output.
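The periodic "copy every N steps" update can be sketched with a tiny linear Q-function. This is a minimal illustration under stated assumptions. The linear model q_values, the constants SYNC_EVERY and LR, and the single fixed transition are all hypothetical and not part of the original problem. Only the online weights are trained. The target weights are refreshed by a hard copy.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical linear Q-function: Q(s, .) = s @ W, with W of shape (state_dim, n_actions).
def q_values(W, s):
    return s @ W

def td_loss(W_online, W_target, s, a, r, s_next, gamma=0.9):
    # Bootstrap from the target network; stop_gradient treats it as a constant.
    target = r + gamma * lax.stop_gradient(jnp.max(q_values(W_target, s_next)))
    return (target - q_values(W_online, s)[a]) ** 2

SYNC_EVERY = 3  # hypothetical N
LR = 0.1        # hypothetical learning rate

W_online = jnp.zeros((2, 2))
W_target = W_online
# One hypothetical transition (s, a, r, s_next), reused at every step.
s, a, r, s_next = jnp.array([1.0, 0.0]), 0, 1.0, jnp.array([0.0, 1.0])

for step in range(1, 7):
    # Differentiate w.r.t. the online weights only (argnums defaults to 0).
    g = jax.grad(td_loss)(W_online, W_target, s, a, r, s_next)
    W_online = W_online - LR * g  # gradient step on the online network
    if step % SYNC_EVERY == 0:
        W_target = W_online       # hard copy: the target catches up every N steps
```

Between syncs, W_target lags behind W_online. That lag is what keeps the bootstrap target stable.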
lax.stop_gradient blocks autodiff through the wrapped value while leaving
the forward pass completely unchanged. Without it, gradients would flow
into the target network parameters, destabilizing training.
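A minimal way to see "forward unchanged, gradient blocked" is to compare x * x with x * stop_gradient(x). The function names f and f_stopped are illustrative. Both functions return the same value. Under jax.grad, however, the wrapped factor is treated as a constant, so the derivative drops from 2x to x.

```python
import jax
from jax import lax

def f(x):
    return x * x                     # d/dx = 2x

def f_stopped(x):
    return x * lax.stop_gradient(x)  # second factor is a constant for autodiff: d/dx = x

x = 3.0
print(f(x), f_stopped(x))                      # 9.0 9.0 (identical forward values)
print(jax.grad(f)(x), jax.grad(f_stopped)(x))  # 6.0 3.0 (gradients differ)
```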
stop_gradient appears in:
- DQN and its variants (Double DQN, Dueling DQN, Rainbow).
- Actor-critic methods (DDPG, TD3, SAC target critics).
- Distillation / knowledge transfer (teacher is a fixed reference).
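As one illustration from the list above, here is a sketch of a Double DQN bootstrap target. It assumes the standard Double DQN rule: the online network picks the next action and the target network evaluates it. The function name double_dqn_target and the sample Q-values are hypothetical. As in plain DQN, the whole bootstrap term is wrapped in stop_gradient.

```python
import jax.numpy as jnp
from jax import lax

def double_dqn_target(q_online_next, q_target_next, reward, gamma):
    # The online network selects the next action...
    a_star = jnp.argmax(q_online_next)
    # ...the target network evaluates it, and autodiff treats the result as a constant.
    return reward + gamma * lax.stop_gradient(q_target_next[a_star])

q_online_next = jnp.array([0.5, 2.0, 1.0])
q_target_next = jnp.array([3.0, 1.5, 4.0])
print(double_dqn_target(q_online_next, q_target_next, 1.0, 0.9))  # approximately 2.35 (= 1 + 0.9 * 1.5)
```

For comparison, a plain DQN target on the same inputs would take the max over q_target_next, giving 1 + 0.9 * 4.0 = 4.6.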
Worked mini-example
```python
import jax
import jax.numpy as jnp
from jax import lax

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])
reward, gamma, action = 1.0, 0.9, 0

# WITHOUT stop_gradient: gradient flows into q_target (wrong)
target_bad = reward + gamma * jnp.max(q_target)
loss_bad = (target_bad - q_online[action]) ** 2
# grad wrt q_target is non-zero: wrong, it's a fixed reference

# WITH stop_gradient: gradient is blocked (correct)
target = reward + gamma * lax.stop_gradient(jnp.max(q_target))
loss = (target - q_online[action]) ** 2
# grad wrt q_target is 0: correct, no gradient through target
print(loss)  # scalar TD loss
```
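The comments in the mini-example state what the gradients look like without actually computing them. The sketch below checks the claim with jax.grad. It wraps each loss in a function of both arrays and differentiates with respect to both via argnums=(0, 1). The function names loss_bad and loss_good are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

reward, gamma, action = 1.0, 0.9, 0

def loss_bad(q_online, q_target):
    target = reward + gamma * jnp.max(q_target)  # no stop_gradient
    return (target - q_online[action]) ** 2

def loss_good(q_online, q_target):
    target = reward + gamma * lax.stop_gradient(jnp.max(q_target))
    return (target - q_online[action]) ** 2

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])

# Gradients w.r.t. (q_online, q_target) for each version of the loss.
g_online_bad, g_target_bad = jax.grad(loss_bad, argnums=(0, 1))(q_online, q_target)
g_online, g_target = jax.grad(loss_good, argnums=(0, 1))(q_online, q_target)

print(g_target_bad)  # non-zero at argmax(q_target)
print(g_target)      # all zeros: nothing reaches the target
print(g_online)      # identical to g_online_bad: the online gradient is unaffected
```

Here the TD error is 4.6 - 1.0 = 3.6. The online gradient is therefore -2 * 3.6 = -7.2 at the taken action. Without stop_gradient, the target would also receive 2 * 3.6 * 0.9 = 6.48 at its argmax.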
Common pitfalls
- stop_gradient only affects autodiff. The forward value is identical with or without it; the distinction matters only under jax.grad / jax.vjp.
- It stops the gradient through the WRAPPED value only. stop_gradient(jnp.max(q_target)) stops the gradient through that scalar; everything else in the expression still participates normally.
- action_taken arrives as a float. Scalar inputs are delivered as float32, and JAX rejects float indices, so index with q_online[jnp.int32(action_taken)] to avoid a type error.
- Common mistake: wrapping q_online instead of q_target kills the training signal entirely, because the online network would then receive zero gradient.
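The pitfalls above can be checked in a few lines. This is a sketch: action_taken is simulated as a float32 scalar, and loss_wrong is a hypothetical, deliberately broken loss that wraps the online Q-value instead of the target.

```python
import jax
import jax.numpy as jnp
from jax import lax

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])
reward, gamma = 1.0, 0.9
action_taken = jnp.float32(0.0)  # simulated float-typed action index

# Pitfall 1: the forward value is unchanged by stop_gradient.
assert jnp.max(q_target) == lax.stop_gradient(jnp.max(q_target))

# Pitfall 3: cast to an integer before indexing; a float32 index is rejected by JAX.
a = jnp.int32(action_taken)

# Common mistake: stop_gradient on the online Q-value instead of the target.
def loss_wrong(q_online, q_target):
    target = reward + gamma * jnp.max(q_target)
    return (target - lax.stop_gradient(q_online[a])) ** 2

g_online = jax.grad(loss_wrong)(q_online, q_target)
print(g_online)  # all zeros: the online network gets no learning signal
```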
Problem
Implement td_loss(q_online, q_target, action_taken, reward, gamma) that
computes the one-step TD error, blocking gradients through the target bootstrap.
- q_online: 1-D JAX array of shape (n_actions,).
- q_target: 1-D JAX array of shape (n_actions,).
- action_taken: scalar delivered as a float; cast to int for indexing.
- reward: scalar.
- gamma: scalar discount factor.

Returns: a scalar, (reward + gamma * stop_gradient(max(q_target)) - q_online[action])².
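One possible implementation, written directly from the specification above, is sketched below. It is not the site's reference solution. It casts action_taken with jnp.int32, as the pitfalls section recommends, and the cast also works on traced values under jax.jit.

```python
import jax
import jax.numpy as jnp
from jax import lax

def td_loss(q_online, q_target, action_taken, reward, gamma):
    # action_taken may arrive as a float; cast to an integer index first
    action = jnp.int32(action_taken)
    # Bootstrap from the target network; autodiff treats it as a constant
    target = reward + gamma * lax.stop_gradient(jnp.max(q_target))
    td_error = target - q_online[action]
    return td_error ** 2

q_online = jnp.array([1.0, 2.0])
q_target = jnp.array([3.0, 4.0])

loss = jax.jit(td_loss)(q_online, q_target, 0.0, 1.0, 0.9)
print(loss)  # approximately 12.96 (= (1 + 0.9 * 4 - 1) ** 2)

# Gradients w.r.t. both Q arrays: only q_online should receive a signal.
g_online, g_target = jax.grad(td_loss, argnums=(0, 1))(q_online, q_target, 0.0, 1.0, 0.9)
```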