Param Surgery — Freeze First Dense
Why this matters
The previous two surgery problems (pos 93, 94) reached into a params pytree and replaced values directly: swap a kernel, zero a kernel. This problem applies the same surgical mindset to a different target — the gradient pytree — to freeze a layer.
Freezing a layer means: during a training step, the optimizer must not move its parameters. Use cases:
- Fine-tuning: pre-train a backbone, freeze it, train only the head on a downstream task.
- Curriculum freezing: train layer-by-layer, unfreezing one at a time.
- Architecture search: freeze everything but the candidate block to isolate its contribution.
The cleanest, most surgical way: zero the gradients of the frozen
params before the optimizer sees them. With zero gradient,
param ← param − lr · 0 = param. No movement, no change.
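A quick numeric check of that arithmetic, as a minimal pure-JAX sketch. The toy values for param, grad and lr are illustrative assumptions, not part of the problem:

```python
import jax.numpy as jnp

param = jnp.array([0.5, -1.25, 2.0])
grad = jnp.array([0.3, 0.1, -0.7])
lr = 0.1

# Ordinary SGD step: the parameter moves.
moved = param - lr * grad
# Frozen step: the gradient is replaced by zeros before the update.
frozen = param - lr * jnp.zeros_like(grad)

print(bool(jnp.array_equal(frozen, param)))  # True: lr * 0 is 0, and param - 0 is exactly param
print(bool(jnp.array_equal(moved, param)))   # False
```

Because subtracting an exact zero is exact in floating point, the frozen values are bit-identical to the originals rather than merely close.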
Position 76 (flax-param-freezing) introduced the same technique
in a “training” context. Here we re-cast it as surgery —
reaching into a pytree at known paths and replacing leaves. Same
skill, different lens. The training problem stressed what
happens; this problem stresses the act of reaching in.
The pattern: tree_map_with_path
jax.tree_util.tree_map walks every leaf in a pytree, but it
doesn’t know where each leaf lives. To freeze a specific
layer, you need the path — the sequence of dict keys you walked
to reach the leaf.
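import jax
import jax.numpy as jnp

# Replace every gradient leaf under "Dense_0" with zeros; leave the rest untouched.
grads_zeroed = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g)
    if any(getattr(k, "key", None) == "Dense_0" for k in path)
    else g,
    grads,
)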
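path is a key path, i.e. a tuple of key entries. For
grads["Dense_0"]["kernel"], the path is
(DictKey(key="Dense_0"), DictKey(key="kernel")). The check
getattr(k, "key", None) == "Dense_0" is safe on any entry type:
entries without a key attribute (like SequenceKey, which stores a
list index in idx) return None, so the comparison is simply False.
Because the check uses any(), it also matches when the tree keeps
Flax's outer "params" level, as in grads["params"]["Dense_0"]["kernel"].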
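To see these paths concretely, the sketch below flattens a toy gradient pytree with jax.tree_util.tree_flatten_with_path, checks what a list index looks like, and wraps the same zeroing lambda in a helper. The shapes and the helper name zero_grads_under are hypothetical, chosen only for illustration:

```python
import jax
import jax.numpy as jnp

# Toy gradient pytree shaped like an MLP's grads (shapes are illustrative).
grads = {
    "Dense_0": {"kernel": jnp.ones((2, 8)), "bias": jnp.ones((8,))},
    "Dense_1": {"kernel": jnp.ones((8, 1)), "bias": jnp.ones((1,))},
}

# Inspect every leaf's path: each entry here is a DictKey whose .key is the dict key.
for path, leaf in jax.tree_util.tree_flatten_with_path(grads)[0]:
    print(jax.tree_util.keystr(path), leaf.shape)

# A list element gets a SequenceKey (it stores .idx, not .key), so getattr returns None.
list_path = jax.tree_util.tree_flatten_with_path([jnp.zeros(3)])[0][0][0]
print(list_path, getattr(list_path[0], "key", None))


def zero_grads_under(tree, layer_name):
    """Hypothetical helper: zero every leaf whose path passes through layer_name."""
    return jax.tree_util.tree_map_with_path(
        lambda path, g: jnp.zeros_like(g)
        if any(getattr(k, "key", None) == layer_name for k in path)
        else g,
        tree,
    )


zeroed = zero_grads_under(grads, "Dense_0")
```

Parameterizing the layer name makes the same surgery reusable for freezing Dense_1 or any other submodule.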
Why “before the optimizer sees them”?
Plain SGD: param ← param − lr · grad. A zero grad means a zero
update, so optax.set_to_zero() (a transform on updates, usually routed
to the frozen params with optax.multi_transform) and grad zeroing
reach the same result.
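Stateful optimizers (Adam, RMSProp, momentum SGD) are different, because their running averages absorb whatever grad you feed them. Zeroing the grad at the source makes those averages decay toward zero on the frozen params; they stay at exactly zero only if the freeze starts from a fresh optimizer state. If the moments are already nonzero, zeroed grads do not stop movement: the momentum term keeps pushing the params on subsequent steps. Likewise, decoupled weight decay (as in AdamW) shrinks params even when their grads are zero. Zeroing the update instead (after the optimizer has computed it) guarantees no movement regardless of state, and if you mask the output of the full optimizer, its running averages keep tracking the real gradients, which is useful when you might unfreeze later. For a one-shot freeze with plain SGD (this problem), both routes are identical.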
Verifying the freeze worked: norm sandwich
The cleanest sanity check is parameter-identity:
norm_before = jnp.linalg.norm(params["Dense_0"]["kernel"])
# ... step ...
norm_after = jnp.linalg.norm(new_params["Dense_0"]["kernel"])
assert float(norm_after) == float(norm_before)
If the kernel didn’t move, its L2 norm didn’t change. The grader
asserts equality (within float tolerance) of norm_after and
norm_before directly.
Note: comparing kernel_after == kernel_before element-wise is the
stricter check, since different kernels can share a norm, but reducing
to a scalar is friendlier to the test infrastructure (single float vs
whole array).
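The norm check is a necessary condition rather than a strict proof: a kernel whose entries are merely rearranged keeps its norm. A small pure-JAX illustration, where the toy 2×3 kernel is an assumption chosen so every partial sum is exact:

```python
import jax.numpy as jnp

kernel = jnp.arange(6.0).reshape(2, 3)
shuffled = kernel[::-1]  # rows swapped: same entries, different positions

norm_equal = bool(jnp.linalg.norm(kernel) == jnp.linalg.norm(shuffled))
elementwise_equal = bool(jnp.array_equal(kernel, shuffled))
print(norm_equal, elementwise_equal)  # True False
```

In practice an SGD step that changes the kernel yet preserves its norm exactly is vanishingly unlikely, so the scalar comparison is a reasonable grader signal; jnp.array_equal is the option to reach for when you want certainty.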
Architecture
Same MLP as pos 76:
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x
Two Dense layers, named Dense_0 and Dense_1 automatically by
Flax. We freeze Dense_0 (kernel + bias) and train Dense_1.
Common pitfalls
- Zeroing the kernel itself instead of its grad: that’s a different surgery (pos 94). Here we keep the param values intact and zero the gradient on its way to the optimizer.
- Forgetting to capture norm_before: the step produces new param values, and if you rebind params to the updated pytree the old kernel is gone. Save the norm BEFORE the step, then check it AFTER on the new params.
- Wrong dense name: it’s Dense_0, not dense_0. Flax auto-names submodules as ClassName_index.
- Mixing up grad zeroing and update zeroing: for SGD they’re equivalent, for Adam they’re not. The reference uses grad zeroing, which is what this problem teaches.
Problem
Build the MLP Dense(8) → relu → Dense(1). Compute MSE loss on
(x, y). Run ONE SGD step with the FIRST Dense layer frozen by
zeroing its gradients via tree_map_with_path. Return:
[loss, frozen_kernel_norm_after, frozen_kernel_norm_before]
The two norms must be EQUAL; that equality is the check that the freeze worked.
Inputs:
- seed: float (cast to int); PRNG seed.
- x: 2-D (N, D) input.
- y: 1-D (N,) targets.
- lr: float; SGD learning rate.
Output: 1-D (3,) — [loss, norm_after, norm_before].