
Param Surgery — Freeze First Dense

Why this matters

The previous two surgery problems (pos 93, 94) reached into a params pytree and replaced values directly: swap a kernel, zero a kernel. This problem applies the same surgical mindset to a different target — the gradient pytree — to freeze a layer.

Freezing a layer means: during a training step, the optimizer must not move its parameters. Use cases:

  • Fine-tuning: pre-train a backbone, freeze it, train only the head on a downstream task.
  • Curriculum freezing: train layer-by-layer, unfreezing one at a time.
  • Architecture search: freeze everything but the candidate block to isolate its contribution.

The cleanest, most surgical way: zero the gradients of the frozen params before the optimizer sees them. With a zero gradient, a plain SGD step gives param ← param - lr · 0 = param. No movement, no change.
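A minimal sketch of that claim, assuming a hand-rolled SGD step over a toy one-leaf params dict (the names sgd_step and w are illustrative, not part of the problem):

```python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, lr):
    # Plain SGD applied leaf by leaf: p <- p - lr * g
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.array([1.0, -2.0, 3.0])}
zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)  # every grad is 0

new_params = sgd_step(params, zero_grads, lr=0.1)
# p - 0.1 * 0.0 is exactly p, so the arrays match element for element
print(bool(jnp.array_equal(new_params["w"], params["w"])))
```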

Position 76 (flax-param-freezing) introduced the same technique in a “training” context. Here we re-cast it as surgery — reaching into a pytree at known paths and replacing leaves. Same skill, different lens. The training problem stressed what happens; this problem stresses the act of reaching in.

The pattern: tree_map_with_path

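jax.tree_util.tree_map walks every leaf in a pytree, but the function it applies receives only the leaf, not where that leaf lives. To freeze a specific layer you need the path: the sequence of keys (here, dict keys) you walked to reach the leaf. jax.tree_util.tree_map_with_path passes that path as the first argument: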

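import jax
import jax.numpy as jnp

# Zero every gradient leaf whose path passes through the "Dense_0" key;
# all other leaves are returned unchanged.
grads_zeroed = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g)
                    if any(getattr(k, "key", None) == "Dense_0" for k in path)
                    else g,
    grads,
)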

path is a KeyPath, i.e. a tuple of key entries. For grads["Dense_0"]["kernel"] it is (DictKey(key="Dense_0"), DictKey(key="kernel")). The check getattr(k, "key", None) == "Dense_0" is safe for any entry type: entries without a key attribute (such as SequenceKey for list indices, which stores .idx instead) make getattr return None, and the comparison is simply False.
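To see the key entries directly, here is a quick sketch using jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr on a toy grads tree. The tree is made up for illustration; the extra list leaf exists only so that a SequenceKey shows up next to the DictKey entries:

```python
import jax
import jax.numpy as jnp

# Toy grads pytree shaped like Flax params, plus a list to produce a SequenceKey.
grads = {
    "Dense_0": {"kernel": jnp.ones((2, 3)), "bias": jnp.ones(3)},
    "Dense_1": {"kernel": jnp.ones((3, 1)), "bias": jnp.ones(1)},
    "extra": [jnp.ones(2)],
}

leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads)
for path, leaf in leaves_with_paths:
    # keystr renders a KeyPath as a readable string; also show each entry's type
    print(jax.tree_util.keystr(path), [type(k).__name__ for k in path])

def in_dense_0(path):
    # SequenceKey has .idx rather than .key, so getattr falls back to None
    return any(getattr(k, "key", None) == "Dense_0" for k in path)

frozen = [jax.tree_util.keystr(p) for p, _ in leaves_with_paths if in_dense_0(p)]
print(frozen)
```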

Why “before the optimizer sees them”?

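Plain SGD: param ← param - lr · grad. A zero grad means a zero update, so zeroing the grads and applying optax.set_to_zero() (a transform that replaces the updates of the params it covers with zeros, usually routed to the frozen subtree with optax.multi_transform) reach the same result.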

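But stateful optimizers (Adam, RMSProp, momentum SGD) keep running statistics that absorb whatever grad you feed them. Zeroing the grad at the source makes those statistics decay toward zero on the frozen params; starting from a freshly initialized state they stay at zero, and so does the update. Mid-training is different: momentum accumulated in earlier steps keeps moving the param for a while even though its grad is now zero. Zeroing the update after the stateful transform instead stops the param immediately while the statistics keep tracking the real gradients, which is useful if you might unfreeze later. For a one-shot freeze with plain SGD (this problem), both routes are identical.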

Verifying the freeze worked: norm sandwich

A cheap sanity check uses a scalar stand-in for parameter identity, the kernel's norm before and after the step:

import jax.numpy as jnp

norm_before = jnp.linalg.norm(params["Dense_0"]["kernel"])  # record BEFORE the step
# ... step ...
norm_after = jnp.linalg.norm(new_params["Dense_0"]["kernel"])
assert float(norm_after) == float(norm_before)

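If the kernel didn’t move, its norm (the Frobenius norm, which jnp.linalg.norm computes by default for a 2-D array) didn’t change. The converse does not hold, since a different kernel can have the same norm, so this is a necessary check rather than a proof. The grader asserts that norm_after and norm_before are equal within float tolerance; with plain SGD and zeroed grads they should in fact be bit-identical, because p - lr * 0.0 == p for any finite p.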

Note: comparing kernel_after == kernel_before element-wise also works, but reducing to a scalar is friendlier to the test infrastructure (single float vs whole array).
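A small illustration of why the norm is a proxy rather than a proof, assuming a made-up 2×3 kernel: flipping its rows yields a different matrix with exactly the same Frobenius norm, and only the element-wise comparison catches the change.

```python
import jax.numpy as jnp

kernel_before = jnp.arange(6.0).reshape(2, 3)
# Same entries, rows in reverse order: a genuinely different kernel.
kernel_after = kernel_before[::-1]

norm_before = jnp.linalg.norm(kernel_before)  # Frobenius norm for a 2-D array
norm_after = jnp.linalg.norm(kernel_after)

print(bool(norm_after == norm_before))                     # scalar check passes
print(bool(jnp.array_equal(kernel_after, kernel_before)))  # element-wise check fails
```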

Architecture

Same MLP as pos 76:

import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8)(x)   # auto-named Dense_0
        x = nn.relu(x)
        x = nn.Dense(1)(x)   # auto-named Dense_1
        return x

Two Dense layers, named Dense_0 and Dense_1 automatically by Flax. We freeze Dense_0 (kernel + bias) and train Dense_1.

Common pitfalls

  • Zeroing the kernel itself instead of its grad: that’s a different surgery (pos 94). Here we keep the param values intact and zero the gradient on its way to the optimizer.
  • Forgetting to capture norm_before: the step returns new params, and once you rebind them to the same name (params = ...), the old kernel is gone. Save the norm BEFORE the step, then compute the norm AFTER on the new params.
  • Wrong dense name: it’s Dense_0, not dense_0. Flax auto-names unnamed submodules as ClassName_index, so the capital D from the class name is kept.
  • Mixing up grad zeroing and update zeroing: for SGD they’re equivalent, for Adam they’re not. The reference uses grad zeroing, which is what this problem teaches.

Problem

Build the MLP Dense(8) → relu → Dense(1). Compute MSE loss on (x, y). Run ONE SGD step with the FIRST Dense layer frozen by zeroing its gradients via tree_map_with_path. Return:

[loss, frozen_kernel_norm_after, frozen_kernel_norm_before]

The two norms must be EQUAL; that equality is the check that the freeze worked.

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (N, D) input.
  • y: 1-D (N,) targets.
  • lr: float — SGD learning rate.

Output: 1-D array of shape (3,): [loss, norm_after, norm_before].

