medium
primitives
Grad through stop_gradient
Why this matters
lax.stop_gradient(x) blocks autodiff from flowing through x. From
the perspective of the differentiation engine, the wrapped value is treated
as a constant, even though the forward computation is unchanged.
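A minimal sketch of this "constant for autodiff, unchanged in the forward pass" behaviour. The functions f_plain and f_stopped are hypothetical names chosen for illustration; they compute the same value but differentiate differently.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_plain(x):
    # Ordinary square: both factors depend on x for autodiff.
    return x * x


def f_stopped(x):
    # Same forward value, but the second factor is a constant for autodiff.
    return x * lax.stop_gradient(x)


x0 = jnp.array(3.0)

value_plain = f_plain(x0)
value_stopped = f_stopped(x0)

grad_plain = jax.grad(f_plain)(x0)      # d/dx (x * x) = 2x
grad_stopped = jax.grad(f_stopped)(x0)  # d/dx (x * c) = c, where c = x
```

Both functions return the same forward value, while the gradient of f_stopped is half that of f_plain because only one factor is differentiated.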
stop_gradient is a foundational pattern in JAX because it underpins several common techniques:
- Target networks (DQN, DDPG): the target Q-value should not receive gradients; wrapping it with stop_gradient enforces this cleanly.
- Straight-through estimators: let a discrete forward pass influence the backward pass via a continuous surrogate, while stopping the gradient from flowing back through the discrete op.
- Freezing sub-networks: treat a pretrained encoder as a feature extractor by applying stop_gradient to its output before computing the loss.
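The first two uses above can be sketched roughly as follows. Everything here is illustrative: ste_round, q_value, and td_loss are hypothetical helpers, the one-parameter Q-function is a toy, and for brevity the bootstrapped target reuses the same parameter theta rather than a separate target network.

```python
import jax
import jax.numpy as jnp
from jax import lax


def ste_round(x):
    # Forward: x + (round(x) - x) evaluates to round(x).
    # Backward: the bracketed term is a constant, so d/dx is 1 (identity surrogate).
    return x + lax.stop_gradient(jnp.round(x) - x)


def q_value(theta, state):
    # Hypothetical one-parameter Q-function, for illustration only.
    return theta * state


def td_loss(theta, state, reward, next_state, gamma=0.5):
    # The bootstrapped target is held constant during differentiation.
    target = reward + gamma * lax.stop_gradient(q_value(theta, next_state))
    return (q_value(theta, state) - target) ** 2


# Gradient passes "straight through" the rounding op.
ste_grad = jax.grad(lambda x: jnp.sum(2.0 * ste_round(x)))(jnp.array([0.3, 1.7]))

# Only the prediction term q_value(theta, state) contributes to the gradient.
td_grad = jax.grad(td_loss)(2.0, 1.0, 1.0, 3.0)
```

With these inputs the prediction is 2 and the target is 4, so the gradient is 2 * (2 - 4) * 1 = -4. Without stop_gradient the target term would also be differentiated and the result would differ.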
Worked mini-example
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0])
w = 3.0

def loss(w):
    return jnp.sum((lax.stop_gradient(x) * w) ** 2)

g = jax.grad(loss)(w)
# grad w.r.t. w = 2 * w * sum(x²) = 2 * 3 * 5 = 30.0
# grad w.r.t. x would be 0 if we tried (stop_gradient blocks it)
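To check the last comment directly, one option is to make x an explicit argument and differentiate with respect to each argument in turn. The two-argument loss_xw below is a hypothetical variant of the loss above, written only for this check.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss_xw(x, w):
    # Same loss as the mini-example, but with x passed in explicitly.
    return jnp.sum((lax.stop_gradient(x) * w) ** 2)


x = jnp.array([1.0, 2.0])
w = 3.0

# argnums selects which positional argument to differentiate.
grad_x = jax.grad(loss_xw, argnums=0)(x, w)  # blocked by stop_gradient
grad_w = jax.grad(loss_xw, argnums=1)(x, w)  # 2 * w * sum(x**2)
```

Here grad_x is an array of zeros with the same shape as x, while grad_w matches the value computed in the mini-example.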
Common pitfalls
- Only blocks the gradient, not the forward value: stop_gradient(x) still returns x in the forward pass; only autodiff is blocked.
- Wrong target: wrapping the wrong argument (e.g., wrapping w instead of x) zeroes the gradient w.r.t. w and defeats the purpose.
- Alternative access: the function lives at jax.lax.stop_gradient and can be called either as jax.lax.stop_gradient or, after from jax import lax, as lax.stop_gradient.
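The pitfalls above can be reproduced in a few lines. This is a sketch with hypothetical function names (loss_right, loss_wrong): it wraps w instead of x, then compares forward values and gradients.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0])
w0 = jnp.array(3.0)


def loss_right(w):
    # Intended: x is the constant, w is differentiated.
    return jnp.sum((lax.stop_gradient(x) * w) ** 2)


def loss_wrong(w):
    # Mistake: w itself is wrapped, so no gradient reaches it.
    return jnp.sum((x * lax.stop_gradient(w)) ** 2)


# The forward values agree, so the bug is invisible in the loss itself.
forward_right = loss_right(w0)
forward_wrong = loss_wrong(w0)

grad_right = jax.grad(loss_right)(w0)
grad_wrong = jax.grad(loss_wrong)(w0)  # zero: the only path to w is blocked

# Both spellings refer to the same function object.
same_function = jax.lax.stop_gradient is lax.stop_gradient
```

Because the forward values are identical, a misplaced stop_gradient typically shows up only as a parameter that never updates.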
Problem
Implement grad_through_stop(x, w) that:
- Defines loss(w) = sum((stop_gradient(x) * w)²).
- Returns jax.grad(loss)(w), the gradient w.r.t. w only.
Args:
- x: 1-D jax array.
- w: scalar.
Returns: scalar, equal to 2 * w * sum(x²).
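The closed form for the return value follows from treating x as a constant. A short derivation, writing sg for stop_gradient (a shorthand introduced here, not part of the problem statement):

```latex
\begin{aligned}
L(w) &= \sum_i \bigl(\operatorname{sg}(x_i)\, w\bigr)^2 \\
\frac{\partial L}{\partial w}
     &= \sum_i 2\,(x_i w)\, x_i
      = 2\, w \sum_i x_i^2 \\
x = (1, 2),\; w = 3:\quad
\frac{\partial L}{\partial w} &= 2 \cdot 3 \cdot (1 + 4) = 30
\end{aligned}
```

Since the loss is differentiated only with respect to w, stop_gradient does not change this particular result; it matters when x is itself a differentiated argument or depends on w.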
Hints
jax
stop-gradient
purity