
Selective stop_gradient (Frozen Layers)

Why this matters

In transfer learning and multi-task models, you often want to freeze some parameters: let them contribute to the forward pass but receive no gradient update. lax.stop_gradient is the canonical JAX tool for this.

Wrapping a parameter as lax.stop_gradient(w) tells autodiff: "treat this value as a constant during the backward pass." The forward value is unchanged; only gradient flow is blocked.

This is the foundation for:

  • Frozen encoder layers in fine-tuning LLMs.
  • Target networks in RL (updated only by soft updates; no gradient flows through the target).
  • Straight-through estimators (stop_gradient lets a discrete op such as rounding pass gradients through as if it were the identity).
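A minimal straight-through estimator sketch, assuming the common pattern x + stop_gradient(round(x) - x); the names ste_round and f are illustrative, not part of JAX. The forward pass returns round(x), while the backward pass sees only the identity term x:

```python
import jax
import jax.numpy as jnp
from jax import lax

def ste_round(x):
    # Forward: x + (round(x) - x) == round(x).
    # Backward: the correction term is wrapped in stop_gradient,
    # so d(ste_round)/dx is 1 (the gradient passes "straight through").
    return x + lax.stop_gradient(jnp.round(x) - x)

def f(x):
    # Hypothetical downstream loss on the rounded values.
    return jnp.sum(ste_round(x) ** 2)

x = jnp.array([0.3, 1.7])
print(ste_round(x))    # [0. 2.]
print(jax.grad(f)(x))  # [0. 4.]  (2 * round(x), times an identity Jacobian)
```

Without the stop_gradient trick, jax.grad of jnp.round is zero almost everywhere, so nothing upstream of the rounding would learn.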

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

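def loss(w_train, w_fixed):
    return jnp.sum((w_train + lax.stop_gradient(w_fixed)) ** 2)

# Differentiate w.r.t. both arguments: gradient flows through w_train only
g_train, g_fixed = jax.grad(loss, argnums=(0, 1))(1.0, 5.0)
print(g_train)   # 12.0  (2*(1+5)*1)
print(g_fixed)   # 0.0   (blocked by stop_gradient)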

To freeze a whole pytree of parameters at once: jax.tree.map(lax.stop_gradient, frozen_params).
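A sketch of freezing one subtree of a parameter pytree, assuming a hypothetical two-part model stored as a dict with "encoder" (frozen) and "head" (trainable) entries. It uses jax.tree.map, available in recent JAX releases (older releases spell it jax.tree_util.tree_map):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical parameters: a frozen encoder and a trainable head.
params = {
    "encoder": {"w": jnp.array(2.0)},
    "head": {"w": jnp.array(3.0)},
}

def loss(params, x):
    # Apply stop_gradient to every leaf of the encoder subtree.
    enc = jax.tree.map(lax.stop_gradient, params["encoder"])
    h = enc["w"] * x               # encoder forward pass (value unchanged)
    y = params["head"]["w"] * h    # head forward pass (gradient flows)
    return jnp.sum(y ** 2)

grads = jax.grad(loss)(params, jnp.array([1.0, 2.0]))
print(grads["encoder"]["w"])  # 0.0    (blocked)
print(grads["head"]["w"])     # 120.0  (2 * 3 * 2**2 * (1 + 4))
```

The frozen leaf comes back as a zero rather than being dropped, so the gradient pytree keeps the same structure as params and can be passed to an optimizer unchanged.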

Common pitfalls

  • Stopping the wrong parameter: wrapping w_train instead of w_freeze zeroes the gradient of the very parameter you want to train, defeating the purpose.
  • Forward value is unchanged: stop_gradient only affects autodiff; the value you pass in is returned as-is in the forward pass.
  • Scope matters: stop_gradient blocks gradient only through the wrapped sub-expression, not globally; other uses of the same variable still receive gradient.
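To see that the block is local to the wrapped sub-expression, compare two illustrative functions of a scalar w. In f_partial only the first factor is frozen, so the derivative of c * w (with c equal to w's value) is w, not 2w:

```python
import jax
from jax import lax

def f_partial(w):
    # First factor is a constant for autodiff; the second w is still live.
    return lax.stop_gradient(w) * w

def f_plain(w):
    return w * w

print(jax.grad(f_partial)(3.0))  # 3.0  (only one path contributes)
print(jax.grad(f_plain)(3.0))    # 6.0  (d/dw w^2 = 2w)
```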

Problem

Implement selective_stop_grad(w_train, w_freeze, x) that:

  1. Defines loss(w_train) = sum((w_train * x + stop_gradient(w_freeze)) ** 2).
  2. Returns jax.grad(loss)(w_train), the gradient w.r.t. w_train only.

Arguments:

  • w_train: scalar float.
  • w_freeze: scalar float (frozen; gradient blocked).
  • x: 1-D JAX array.

Returns: scalar β€” gradient of the loss w.r.t. w_train.

The analytic gradient is 2 * sum(x * (w_train * x + w_freeze)).
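A quick hand check of the formula, using made-up values w_train = 2, w_freeze = 1, x = [1, 2] chosen only for illustration:

```latex
\frac{\partial L}{\partial w_{\text{train}}}
  = 2 \sum_i x_i \,(w_{\text{train}} x_i + w_{\text{freeze}})
  = 2\,[\,1 \cdot (2 \cdot 1 + 1) + 2 \cdot (2 \cdot 2 + 1)\,]
  = 2\,(3 + 10) = 26
```

w_freeze appears only as a constant offset inside each residual; it shifts the value of the gradient but receives no gradient of its own.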

Hints

jax stop-gradient pytree freezing
