medium primitives

Grad through stop_gradient

Why this matters

lax.stop_gradient(x) blocks autodiff from flowing through x. From the perspective of the differentiation engine, the wrapped value is treated as a constant, even though the forward computation is unchanged.

It is a foundational pattern in JAX, with three common uses:

  • Target networks (DQN, DDPG): the target Q-value should not receive gradients; wrapping it with stop_gradient enforces this cleanly.
  • Straight-through estimators: the forward pass uses a discrete op while the backward pass uses the gradient of a continuous surrogate; stop_gradient keeps autodiff from flowing back through the discrete op.
  • Freezing sub-networks: treat a pretrained encoder as a fixed feature extractor by applying stop_gradient to its output before computing the loss.
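
The straight-through pattern is the least obvious of the three, so here is a minimal sketch of it. The helper name round_ste and the choice of rounding as the discrete op are assumptions made for illustration; the identity function plays the role of the continuous surrogate.

```python
import jax
import jax.numpy as jnp
from jax import lax

def round_ste(x):
    # Hypothetical helper: straight-through rounding.
    # Forward value: x + (round(x) - x), which equals round(x).
    # Backward: the stop_gradient term is treated as a constant,
    # so the derivative w.r.t. x is that of the identity (1).
    return x + lax.stop_gradient(jnp.round(x) - x)

def ste_loss(x):
    return jnp.sum(round_ste(x) ** 2)

x_in = jnp.array([0.4, 1.6])
y = round_ste(x_in)            # forward pass sees the rounded values
g = jax.grad(ste_loss)(x_in)   # backward pass gives 2 * round(x)
```

Without the stop_gradient term, differentiating jnp.round directly would give a zero gradient almost everywhere, which is why the surrogate is needed.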

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0])
w = 3.0

def loss(w):
    return jnp.sum((lax.stop_gradient(x) * w) ** 2)

g = jax.grad(loss)(w)
# grad w.r.t. w = 2 * w * sum(x²) = 2 * 3 * 5 = 30.0
# The gradient w.r.t. x would be zero if we asked for it, because stop_gradient blocks it.
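
To check the claim about the gradient with respect to x, the sketch below makes x an explicit argument and compares a blocked loss against an unblocked one. The names loss_blocked and loss_open are illustrative and do not come from the exercise.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0])
w = 3.0

def loss_blocked(x, w):
    # x is wrapped, so autodiff treats it as a constant.
    return jnp.sum((lax.stop_gradient(x) * w) ** 2)

def loss_open(x, w):
    # Same forward computation, no gradient blocking.
    return jnp.sum((x * w) ** 2)

gx_blocked = jax.grad(loss_blocked, argnums=0)(x, w)  # all zeros
gx_open = jax.grad(loss_open, argnums=0)(x, w)        # 2 * w**2 * x
gw = jax.grad(loss_blocked, argnums=1)(x, w)          # 2 * w * sum(x**2)
```

Both losses return the same forward value; only the gradient with respect to x differs.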

Common pitfalls

  • Blocks only the gradient, not the forward value: stop_gradient(x) still returns x in the forward pass; only autodiff is affected.
  • Wrong target: wrapping the wrong argument (e.g., wrapping w instead of x) zeroes the gradient w.r.t. w and defeats the purpose.
  • Import path: the function lives at jax.lax.stop_gradient; call it by its full name, or as lax.stop_gradient after from jax import lax.
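
The wrong-target pitfall is easy to reproduce. In this sketch, the names loss_right and loss_wrong are assumptions for illustration: both compute the same forward value, but only one has a useful gradient with respect to w.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0])
w = 3.0

def loss_right(w):
    # Intended: block gradients through x only.
    return jnp.sum((lax.stop_gradient(x) * w) ** 2)

def loss_wrong(w):
    # Pitfall: w is wrapped, so the loss is constant w.r.t. w for autodiff.
    return jnp.sum((x * lax.stop_gradient(w)) ** 2)

g_right = jax.grad(loss_right)(w)  # 2 * w * sum(x**2)
g_wrong = jax.grad(loss_wrong)(w)  # zero: nothing left to differentiate
```

Because the forward values match, this mistake produces no error and no change in the loss; it only shows up as a parameter that never updates.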

Problem

Implement grad_through_stop(x, w) that:

  1. Defines loss(w) = sum((stop_gradient(x) * w)²).
  2. Returns jax.grad(loss)(w), the gradient w.r.t. w only.

Arguments:

  • x: 1-D JAX array.
  • w: scalar.

Returns: scalar, equal to 2 * w * sum(x²).
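
One possible implementation that follows the two steps above is sketched here. It is not the site's reference solution, and the final check against the closed form simply reuses the values from the worked mini-example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def grad_through_stop(x, w):
    # Step 1: the loss closes over x and treats it as a constant for autodiff.
    def loss(w):
        return jnp.sum((lax.stop_gradient(x) * w) ** 2)

    # Step 2: differentiate w.r.t. w only.
    return jax.grad(loss)(w)

x_test = jnp.array([1.0, 2.0])
w_test = 3.0
result = grad_through_stop(x_test, w_test)
expected = 2 * w_test * jnp.sum(x_test ** 2)  # closed form from the problem statement
```

Since x is never an argument of the differentiated function, stop_gradient does not change the result here; it documents the intent and keeps the behaviour correct if loss is later differentiated with respect to x as well.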

Hints

jax stop-gradient purity
