Selective stop_gradient (Frozen Layers)
Why this matters
In transfer learning and multi-task models, you often want to freeze
some parameters: let them contribute to the forward pass but receive no
gradient update. lax.stop_gradient is the canonical JAX tool for this.
Wrapping a parameter with lax.stop_gradient(w) tells autodiff: "treat
this value as a constant during the backward pass." The forward value is
unchanged; only gradient flow is blocked.
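A minimal sketch of that contract, using a toy function chosen for illustration (f and g are hypothetical names): f wraps one factor of w * w in lax.stop_gradient, so its forward value matches g, but autodiff treats the wrapped factor as a constant.

```python
import jax
from jax import lax

def f(w):
    # Forward value equals w * w: stop_gradient returns its input unchanged.
    # Backward: the wrapped factor is a constant c, so d/dw (w * c) = c.
    return w * lax.stop_gradient(w)

def g(w):
    # Ordinary square for comparison: d/dw (w * w) = 2 * w.
    return w * w

print(f(3.0), g(3.0))   # 9.0 9.0  (same forward value)
print(jax.grad(f)(3.0))  # 3.0      (only the unwrapped factor contributes)
print(jax.grad(g)(3.0))  # 6.0
```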
This is the foundation for:
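- Frozen encoder layers when fine-tuning LLMs.
- Target networks in RL (updated only by soft, Polyak-style averaging, with no gradient flowing through the target).
- Straight-through estimators (stop_gradient lets a discrete op such as rounding pass gradients through as if it were the identity).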
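As a sketch of the straight-through idea, assuming rounding as the discrete op (round_ste is a hypothetical helper, and this is one common formulation rather than the only one): the forward pass returns round(x), while the backward pass sees the identity because the correction term is wrapped in stop_gradient.

```python
import jax
import jax.numpy as jnp
from jax import lax

def round_ste(x):
    # Forward: x + (round(x) - x) == round(x).
    # Backward: the stop_gradient term is a constant, so d/dx = 1 (identity).
    return x + lax.stop_gradient(jnp.round(x) - x)

x = jnp.array([0.2, 1.7, -0.6])
print(round_ste(x))                                    # same values as jnp.round(x)
print(jax.grad(lambda v: jnp.sum(round_ste(v)))(x))    # all ones: gradient passes straight through
print(jax.grad(lambda v: jnp.sum(jnp.round(v)))(x))    # all zeros: plain rounding blocks learning
```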
Worked mini-example
import jax
import jax.numpy as jnp
from jax import lax
def loss(w_train, w_fixed):
    return jnp.sum((w_train + lax.stop_gradient(w_fixed)) ** 2)
# Gradient only flows through w_train, not w_fixed
print(jax.grad(loss)(1.0, 5.0)) # 12.0 (2*(1+5)*1)
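One subtlety in the example above: jax.grad differentiates only the first positional argument by default, so w_fixed would receive no gradient there even without stop_gradient. The sketch below (loss_unfrozen is a hypothetical comparison function) asks for gradients w.r.t. both arguments with argnums=(0, 1), which makes the blocking visible.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(w_train, w_fixed):
    return jnp.sum((w_train + lax.stop_gradient(w_fixed)) ** 2)

def loss_unfrozen(w_train, w_fixed):
    # Same loss without stop_gradient, for comparison.
    return jnp.sum((w_train + w_fixed) ** 2)

# Differentiate w.r.t. BOTH arguments so the effect of stop_gradient shows up.
g_train, g_fixed = jax.grad(loss, argnums=(0, 1))(1.0, 5.0)
print(g_train, g_fixed)  # 12.0 0.0

u_train, u_fixed = jax.grad(loss_unfrozen, argnums=(0, 1))(1.0, 5.0)
print(u_train, u_fixed)  # 12.0 12.0
```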
For a whole pytree: jax.tree.map(lax.stop_gradient, frozen_params).
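A sketch of freezing one subtree of a parameter pytree, assuming a hypothetical two-layer model whose "encoder" subtree is frozen and whose "head" is trained (the dict layout, model, and loss are illustrative, not from any library). jax.tree.map needs a reasonably recent JAX; older versions spell it jax.tree_util.tree_map.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical parameters: a frozen "encoder" and a trainable "head".
params = {
    "encoder": {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([0.1, -0.1])},
    "head": {"w": jnp.array([1.0, -1.0])},
}

def model(params, x):
    # Block gradient through every leaf of the encoder subtree.
    enc = jax.tree.map(lax.stop_gradient, params["encoder"])
    h = jnp.tanh(x @ enc["w"] + enc["b"])
    return h @ params["head"]["w"]

def loss(params, x, y):
    return (model(params, x) - y) ** 2

x = jnp.array([0.1, -0.2])
grads = jax.grad(loss)(params, x, 1.0)
# Encoder leaves come back as all-zero arrays; the head gets a real gradient.
print(jax.tree.map(lambda g: bool(jnp.all(g == 0)), grads))
```

The frozen leaves still appear in grads, as zero arrays. An optimizer leaves them unchanged only if a zero gradient maps to a zero update: plain SGD does, but an optimizer that applies weight decay would still move them.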
Common pitfalls
- Stopping the wrong parameter: wrapping w_train instead of w_freeze cuts the only trainable path, so the gradient w.r.t. w_train is zero, defeating the purpose.
- Forward value is unchanged: stop_gradient only affects autodiff; the value you pass in is returned as-is in the forward pass.
- Scope matters: stop_gradient only blocks gradient for the wrapped sub-expression, not globally. If the same variable also appears unwrapped elsewhere, gradient still flows through that occurrence.
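The first and third pitfalls can be checked directly on a loss of the same shape as the problem below. This sketch uses arbitrary example values; loss_wrong_param and loss_partial are hypothetical variants written only to show the failure modes.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
w_train, w_freeze = 0.5, 2.0

def loss_correct(w_train):
    return jnp.sum((w_train * x + lax.stop_gradient(w_freeze)) ** 2)

def loss_wrong_param(w_train):
    # Pitfall: stopping the trainable parameter instead of the frozen one.
    return jnp.sum((lax.stop_gradient(w_train) * x + w_freeze) ** 2)

def loss_partial(w_train, w_freeze):
    # Pitfall: w_freeze is wrapped in the first term only;
    # the unwrapped second term still sends gradient to it.
    return jnp.sum((w_train * x + lax.stop_gradient(w_freeze)) ** 2) + w_freeze ** 2

print(jax.grad(loss_correct)(w_train))                     # 38.0
print(jax.grad(loss_wrong_param)(w_train))                 # 0.0
print(jax.grad(loss_partial, argnums=1)(w_train, w_freeze))  # 4.0, from w_freeze ** 2
```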
Problem
Implement selective_stop_grad(w_train, w_freeze, x) that:
- Defines loss(w_train) = sum((w_train * x + stop_gradient(w_freeze)) ** 2).
- Returns jax.grad(loss)(w_train), the gradient w.r.t. w_train only.
Arguments:
- w_train: scalar float.
- w_freeze: scalar float (frozen: its gradient is blocked).
- x: 1-D JAX array.
Returns: scalar, the gradient of the loss w.r.t. w_train.
The analytic gradient is 2 * sum(x * (w_train * x + w_freeze)).
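One possible sketch of selective_stop_grad (not the site's reference solution), checked against the analytic formula on arbitrary example values.

```python
import jax
import jax.numpy as jnp
from jax import lax

def selective_stop_grad(w_train, w_freeze, x):
    def loss(w):
        # w_freeze enters the forward pass but autodiff treats it as a constant.
        return jnp.sum((w * x + lax.stop_gradient(w_freeze)) ** 2)
    return jax.grad(loss)(w_train)

x = jnp.array([1.0, -1.0, 2.0])
g = selective_stop_grad(1.5, -0.5, x)
analytic = 2 * jnp.sum(x * (1.5 * x + (-0.5)))
print(g, analytic)  # 16.0 16.0
```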