Param Freezing via Grad Zeroing
Why this matters
There are two equally correct ways to “freeze” a layer in Flax:
Way A: optax.set_to_zero() via optax.multi_transform (covered in pos 73). The optimizer produces zero updates for the frozen group, so apply_updates(p, 0) == p.
Way B: zero the GRADIENTS before they reach the optimizer. The optimizer never sees a non-zero gradient for those params, so the gradient step itself can't move them.
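```python
import jax
import jax.numpy as jnp

grads = jax.grad(loss_fn)(params)
# Replace every gradient leaf under "Dense_0" with zeros of the same shape/dtype
grads = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g)
    if any(getattr(k, "key", None) == "Dense_0" for k in path)
    else g,
    grads,
)
state = state.apply_gradients(grads=grads)
```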
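Both produce identical results for plain SGD (no momentum, no weight decay). They differ for optimizers with internal state:
- Adam, RMSProp, Adagrad keep running statistics (momentum, variance, accumulated squared gradients). With Way A, multi_transform routes the frozen group to set_to_zero(), so the inner optimizer keeps no statistics for those params at all. With Way B, the optimizer still tracks every param but only ever sees zero gradients for the frozen ones, so their momentum and variance estimates decay toward zero step by step (Adagrad's accumulator simply stops growing).
- For pure freezing semantics (“don't move these”, starting from step 0), both are correct: with zero gradients from the start, Adam's moments stay at zero and its update is exactly zero. Way B can leak, though, if momentum was already non-zero when the freeze began, or if the optimizer adds weight decay (e.g. adamw), since either produces a non-zero update from a zero gradient.
For most fine-tuning cases the difference doesn't matter (you freeze from the first step and never change it). But if you ever change what is frozen partway through training, prefer Way A: its update for the frozen group is exactly zero by construction, whereas Way B inherits whatever momentum the optimizer had already built up.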
What’s jax.tree_util.tree_map_with_path?
Same as tree_map, but the function gets (path, leaf) instead of just leaf. path is a KeyPath, i.e. a tuple of key entries; for a nested dict like params["Dense_0"]["kernel"], the path is (DictKey(key='Dense_0'), DictKey(key='kernel')).
To check whether the path mentions "Dense_0":
```python
any(getattr(k, "key", None) == "Dense_0" for k in path)
```
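getattr(k, "key", None) is safe for any path entry: a SequenceKey (list index) stores its position in .idx and a GetAttrKey stores its attribute name in .name, so for those entries getattr returns None and the comparison is simply False.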
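To see the path entries concretely, this sketch maps a predicate over a toy tree that mixes dict keys and a list (the tree and the helper mentions are hypothetical), and also renders each path with jax.tree_util.keystr, which is handy when you're unsure what a mask should match.

```python
import jax
import jax.numpy as jnp

# Hypothetical tree mixing dict keys and a list
tree = {"Dense_0": {"kernel": jnp.ones((2, 2))}, "blocks": [jnp.zeros(3), jnp.zeros(1)]}

def mentions(path, name):
    """True if any DictKey on the path equals `name`.

    SequenceKey entries have .idx (no .key), so getattr returns None for them.
    """
    return any(getattr(k, "key", None) == name for k in path)

flagged = jax.tree_util.tree_map_with_path(lambda path, leaf: mentions(path, "Dense_0"), tree)

# keystr renders a path as a readable string, useful when debugging masks
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
paths = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]
for p in paths:
    print(p)
```

Comparing k.key exactly is safer than a substring test on keystr(path): a check like "Dense_1" in keystr(path) would also match Dense_10, Dense_11, and so on.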
Verifying the freeze worked
The simplest sanity check: compare the parameter's L2 (Frobenius) norm before and after the step. If the param didn't move, the norm is unchanged:
```python
norm_before = jnp.linalg.norm(params["Dense_0"]["kernel"])
# ... step ...
norm_after = jnp.linalg.norm(new_params["Dense_0"]["kernel"])
assert norm_after == norm_before
```
The grader uses this exact check.
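Equal norms are necessary for an unmoved param but not sufficient: any change that only rearranges entries or flips their signs keeps the norm. For your own debugging (the grader still uses the norm check), an element-wise comparison with jnp.array_equal is stricter. The toy kernel below is hypothetical.

```python
import jax.numpy as jnp

# Hypothetical kernel and a "changed" version with its rows swapped
before = jnp.array([[1.0, 2.0], [3.0, 4.0]])
after = before[::-1]  # same entries, different positions

# The norm check cannot tell these apart...
norms_match = bool(jnp.linalg.norm(after) == jnp.linalg.norm(before))
# ...but an element-wise comparison can
values_match = bool(jnp.array_equal(after, before))

print(norms_match, values_match)  # True False
```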
Pitfalls
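- Wrong layer name: Flax names submodules Dense_0, Dense_1, … automatically. Print params.keys() to verify before writing the mask.
- Confusing tree_map with tree_map_with_path: plain tree_map(f, tree) calls f(leaf) with no path info, so you can't decide which leaves to zero.
- Skipping jnp.zeros_like(g) and returning 0.0: the grads tree is expected to mirror the params leaf for leaf, with each gradient matching its parameter's shape and dtype. A bare scalar may happen to broadcast through plain SGD, but it breaks that contract for any code that relies on it (jit/scan carries, shape checks). Use zeros_like.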
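A quick sketch of the last two pitfalls on a hypothetical grads dict: the mask needs the path, so it goes through tree_map_with_path, and returning 0.0 instead of jnp.zeros_like(g) silently swaps a (3, 8) gradient for a shape-() scalar. The helper is_dense0 is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical gradient tree for Dense(8) -> Dense(1) on 3 input features
grads = {"Dense_0": {"kernel": jnp.ones((3, 8))}, "Dense_1": {"kernel": jnp.ones((8, 1))}}

def is_dense0(path):
    return any(getattr(k, "key", None) == "Dense_0" for k in path)

# Correct: zeros with the gradient's own shape and dtype
good = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g) if is_dense0(path) else g, grads
)
# Pitfall: a bare Python scalar replaces the whole (3, 8) leaf
bad = jax.tree_util.tree_map_with_path(
    lambda path, g: 0.0 if is_dense0(path) else g, grads
)

print(good["Dense_0"]["kernel"].shape)      # (3, 8)
print(jnp.shape(bad["Dense_0"]["kernel"]))  # ()
```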
Problem
Build the same MLP as in pos 73: Dense(8) -> relu -> Dense(1).
Run ONE training step on (x, y) with MSE loss, but FREEZE the
first Dense layer (Dense_0) by zeroing its gradients via
tree_map_with_path before calling the optimizer.
Return:
[loss_before, frozen_kernel_norm_after, frozen_kernel_norm_before]
The two norms should be EQUAL — that’s the proof that the freeze worked.
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, D).
- y: 1-D, shape (N,).
- lr: float, the SGD learning rate.
Output: 1-D array of shape (3,): [loss_before, norm_after, norm_before].