hard primitives

Param Freezing via Grad Zeroing

Why this matters

There are two equally-correct ways to “freeze” a layer in Flax:

Way A: optax.set_to_zero() via multi_transform (covered in pos 73). The optimizer produces zero updates for the frozen group, so apply_updates(p, 0) == p.

Way B: Zero the GRADIENTS before they reach the optimizer. With plain SGD the update is -lr * grad, so a zero gradient means a zero update and those params can’t move.

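def in_dense_0(path):
    # True if any dict key along the path is "Dense_0"
    return any(getattr(k, "key", None) == "Dense_0" for k in path)

grads = jax.grad(loss_fn)(params)
grads = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g) if in_dense_0(path) else g,
    grads,
)
state = state.apply_gradients(grads=grads)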

Both produce identical results for plain SGD. They differ for optimizers with internal state:

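  • Adam, RMSProp, Adagrad: have running statistics (momentum, variance, etc.). In Way A, multi_transform routes the frozen group to set_to_zero(), which is stateless: no statistics are kept for those params, and their update is exactly zero whatever the other group’s optimizer does. In Way B the optimizer still tracks the frozen params and is fed zeros, so their running averages decay toward zero. Momentum left over from earlier steps can still produce a non-zero update, and so can decoupled weight decay (e.g. optax.adamw).
  • For pure freezing semantics (“don’t move these”) starting from a freshly initialized optimizer without weight decay, both are correct. Neither keeps the frozen params’ momentum intact across the freeze: Way A stores none, and Way B decays whatever is there.

For most fine-tuning cases, the difference doesn’t matter (you start from a fresh optimizer state and you’re not freezing or unfreezing partway through). But if you ever ARE, prefer Way A: its zero update does not depend on what the optimizer state holds.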

What’s jax.tree_util.tree_map_with_path?

Same as tree_map, but the function gets (path, leaf) instead of just leaf. path is a tuple of KeyPath entries — for a nested dict like params["Dense_0"]["kernel"], the path is (DictKey(key="Dense_0"), DictKey(key="kernel")).

To check whether the path mentions "Dense_0":

any(getattr(k, "key", None) == "Dense_0" for k in path)

getattr(k, "key", None) is safe for any path entry — for SequenceKey (list index) it returns None and the comparison falls through.
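
As a quick illustration, the sketch below builds a boolean mask over a hand-built parameter dict and collects readable path strings with jax.tree_util.keystr. The dict contents and the helper name mentions are assumptions made for the example. The second tree contains a list, to show that a SequenceKey entry simply never matches.

```python
import jax
import jax.numpy as jnp

# Hypothetical params mimicking a Dense(8) -> Dense(1) layout.
params = {
    "Dense_0": {"kernel": jnp.ones((3, 8)), "bias": jnp.zeros((8,))},
    "Dense_1": {"kernel": jnp.ones((8, 1)), "bias": jnp.zeros((1,))},
}


def mentions(path, name):
    # DictKey entries expose .key; entries without it yield None here.
    return any(getattr(k, "key", None) == name for k in path)


# True for every leaf that lives under Dense_0, False elsewhere.
frozen_mask = jax.tree_util.tree_map_with_path(
    lambda path, _: mentions(path, "Dense_0"), params
)

# Human-readable path strings, one per leaf, for debugging a mask.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
path_strings = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]

# A tree with a list: its paths contain a SequenceKey, which has no .key.
list_tree = {"layers": [jnp.ones(2), jnp.ones(2)]}
list_mask = jax.tree_util.tree_map_with_path(
    lambda path, _: mentions(path, "Dense_0"), list_tree
)
```

Printing path_strings is a convenient way to see exactly which keys a path contains before writing the condition.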

Verifying the freeze worked

The cleanest sanity check: compare the parameter’s L2 norm before and after the step. If the param didn’t move, the norm is unchanged:

norm_before = jnp.linalg.norm(params["Dense_0"]["kernel"])
# ... step ...
norm_after = jnp.linalg.norm(new_params["Dense_0"]["kernel"])
assert norm_after == norm_before

The grader uses this exact check.
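
The sketch below runs that check on a hand-written SGD step (p - lr * g) over a hypothetical parameter dict, so it needs only JAX. It also adds an element-wise comparison with jnp.array_equal as a stricter check for your own debugging: equal norms are necessary for an unchanged parameter but not sufficient, since a sign flip, for example, keeps the norm the same.

```python
import jax
import jax.numpy as jnp

# Hypothetical params and all-ones stand-in gradients.
params = {
    "Dense_0": {"kernel": jnp.arange(6.0).reshape(3, 2)},
    "Dense_1": {"kernel": jnp.ones((2, 1))},
}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# Zero the gradients of every leaf under Dense_0.
masked_grads = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g)
    if any(getattr(k, "key", None) == "Dense_0" for k in path)
    else g,
    grads,
)

# Hand-written plain SGD step.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, masked_grads)

before = params["Dense_0"]["kernel"]
after = new_params["Dense_0"]["kernel"]

# The norm check from the text.
norms_match = bool(jnp.linalg.norm(after) == jnp.linalg.norm(before))
# Stricter: every element is identical.
exactly_equal = bool(jnp.array_equal(after, before))
# Why the norm alone is not sufficient: a sign flip preserves it.
flip_keeps_norm = bool(jnp.linalg.norm(-before) == jnp.linalg.norm(before))
```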

Pitfalls

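  • Wrong layer name: Flax names unnamed submodules Dense_0, Dense_1, … automatically. Print params.keys() to verify before writing the mask. If you are holding the full dict returned by init, the layer names sit one level down, under its "params" entry.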
  • Confusing tree_map with tree_map_with_path: plain tree_map(f, tree) calls f(leaf) — no path info. You can’t decide which leaves to zero without the path.
  • Skipping jnp.zeros_like(g) and using 0.0: the optimizer expects an array of the same shape/dtype, not a scalar. Use zeros_like.

Problem

Build the same MLP as in pos 73: Dense(8) -> relu -> Dense(1). Run ONE training step on (x, y) with MSE loss, but FREEZE the first Dense layer (Dense_0) by zeroing its gradients via tree_map_with_path before calling the optimizer.

Return:

[loss_before, frozen_kernel_norm_after, frozen_kernel_norm_before]

The two norms should be EQUAL — that’s the proof that the freeze worked.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N,).
  • lr: float — SGD learning rate.

Output: 1-D (3,): [loss_before, norm_after, norm_before].

