NNX Surgery Freeze
Why this matters
“Freeze the first N layers” is the canonical move of transfer learning: load a pretrained backbone, freeze it, and train only the new head. Done correctly, the backbone’s parameters are bit-for-bit identical before and after the optimizer step. Done incorrectly, the backbone drifts quietly and the pretrained features are destroyed.
There are two clean ways to freeze in nnx 0.12.6:
1. Mask the gradients before optimizer.update. Compute grads normally with nnx.value_and_grad, then zero the entries corresponding to frozen params. SGD on a zero gradient is a no-op: θ ← θ - lr * 0 = θ.
2. Use optax.multi_transform. Label parts of the param tree and route them to different sub-optimizers. The frozen branch gets optax.set_to_zero().
This problem walks through option (1). It’s slightly less
composable than multi_transform but easier to reason about and
requires no new optax machinery — just the same tree_map_with_path
you saw in pos 14 (nnx-state-transform).
The recipe
import jax
import jax.numpy as jnp
from flax import nnx

# 1. Standard forward + grads. (loss_fn, model, optimizer, x, y are defined elsewhere.)
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)

# 2. Zero the entries you want frozen.
def zero_first_layer(path, leaf):
    # Flatten the KeyEntry path to a dotted string for prefix matching.
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    if name.startswith('l1'):
        return jnp.zeros_like(leaf)
    return leaf

masked_grads = jax.tree_util.tree_map_with_path(zero_first_layer, grads)

# 3. Step. Frozen params receive a zero gradient and don't change.
optimizer.update(model, masked_grads)
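To see step 2 in isolation, the same masking callback can run on a plain nested dict. This is a sketch: the dict is a hypothetical stand-in for the nnx grads State, which (per the path listing further down) carries an extra '.value' entry that does not change the prefix match.

```python
import jax
import jax.numpy as jnp

def zero_first_layer(path, leaf):
    # Same matcher as the recipe: flatten the KeyEntry path to a dotted name.
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    if name.startswith('l1'):
        return jnp.zeros_like(leaf)
    return leaf

# Hypothetical gradients as a plain nested dict (a stand-in for nnx grads).
grads = {
    'l1': {'kernel': jnp.full((3, 4), 0.5), 'bias': jnp.full((4,), 0.5)},
    'l2': {'kernel': jnp.full((4, 2), 0.5), 'bias': jnp.full((2,), 0.5)},
}
masked_grads = jax.tree_util.tree_map_with_path(zero_first_layer, grads)
```

Both of l1’s leaves (kernel and bias) come back as zeros. The l2 leaves pass through untouched, and the tree structure is unchanged, so masked_grads can go straight into the optimizer.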
The semantics: the optimizer applies its update rule to every leaf
in masked_grads. With optax.sgd, the update rule is
θ ← θ - lr * g. For frozen leaves, g = 0, so θ ← θ. The
parameter is unchanged bit for bit, since adding -lr * 0 to a finite float returns the same float.
Why grad masking, not param freezing
A naive alternative: train, then OVERWRITE the frozen params with their pre-step values:
pre = jnp.copy(model.l1.kernel.value)
optimizer.update(model, grads)
model.l1.kernel.value = pre # restore
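This works for plain SGD with no momentum but breaks down for any
optimizer with internal state. Adam, for example, accumulates m and v
(first- and second-moment) statistics for every parameter. Even if you
restore the param, the optimizer’s internal stats now reflect a step
that “happened”. Those stats are used the moment you unfreeze the layer,
and cross-leaf transforms such as optax.clip_by_global_norm still count
the frozen gradients in the global norm, which changes the trainable
layers’ updates. Grad masking sidesteps both issues: a leaf that only
ever sees zero gradients keeps zero moment estimates, so its Adam update
is zero too. That zero-update behavior is what optax.set_to_zero()
formalizes inside multi_transform. One caveat: optimizers with decoupled
weight decay (e.g. optax.adamw) add a term proportional to the parameter
itself, so a zero gradient alone does not freeze a leaf under them.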
API: jax.tree_util.tree_map_with_path
Same path-based transform you saw in nnx-state-transform (pos 14).
The path is a tuple of structured KeyEntry objects; the
'.'.join(str(p.key) ...) pattern flattens it to a string for
substring matching.
For a 2-layer model with attribute names l1 and l2, paths look
like:
- ('l1', 'kernel', '.value'): first layer’s kernel
- ('l1', 'bias', '.value'): first layer’s bias
- ('l2', 'kernel', '.value'): second layer’s kernel
- ('l2', 'bias', '.value'): second layer’s bias
name.startswith('l1') matches the first two and zeros their
gradients.
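To check the paths for your own tree rather than trusting a listing, you can flatten it with its key paths. This sketch uses a hypothetical plain dict. It has no '.value' entry, so the dotted names come out one segment shorter than the nnx listing above. jax.tree_util.keystr gives JAX’s own rendering next to the recipe’s dotted form.

```python
import jax
import jax.numpy as jnp

# Hypothetical grads as a plain dict; no '.value' entry, unlike an nnx.State.
grads = {
    'l1': {'kernel': jnp.ones((3, 4)), 'bias': jnp.ones((4,))},
    'l2': {'kernel': jnp.ones((4, 2)), 'bias': jnp.ones((2,))},
}

leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads)
for path, leaf in leaves_with_paths:
    dotted = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    # keystr is JAX's rendering of the path; `dotted` is the recipe's form.
    print(jax.tree_util.keystr(path), dotted, leaf.shape)

dotted_names = ['.'.join(str(p.key) for p in path) for path, _ in leaves_with_paths]
```

JAX flattens dict keys in sorted order, so the l1 leaves come first, bias before kernel.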
Verifying the freeze worked
Save the kernel norm before the step. Save it again after. The two must be exactly equal — not just close, equal. If they differ, you’ve still got a non-zero gradient flowing somewhere.
before = float(jnp.linalg.norm(model.l1.kernel.value))
optimizer.update(model, masked_grads)
after = float(jnp.linalg.norm(model.l1.kernel.value))
assert before == after # bit-for-bit
The test returns [loss, after, before] so the harness can verify
after == before.
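Equal norms are a necessary condition, not a sufficient one: any change that preserves the norm passes the check. A stricter test keeps a copy of the kernel and compares it element-wise with jnp.array_equal. A minimal sketch with hypothetical arrays, where a sign flip stands in for a leak:

```python
import jax.numpy as jnp

kernel_before = jnp.array([[1.0, -2.0], [3.0, 0.5]])
# Hypothetical leak that flips every sign: the values change, the norm does not.
kernel_after = -kernel_before

norms_match = bool(jnp.linalg.norm(kernel_before) == jnp.linalg.norm(kernel_after))
values_match = bool(jnp.array_equal(kernel_before, kernel_after))
```

In the problem’s setting you could keep jnp.copy(model.l1.kernel.value) alongside the norm and assert array equality after the step. The harness itself still checks only [loss, after, before].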
Common pitfalls
- Forgetting to zero the BIAS gradients too. If only kernel gradients are masked, the bias still trains. Use name.startswith('l1') so all of l1’s leaves are caught.
- Building the optimizer inside the step. Optimizers have state (e.g., Adam’s m/v). Build once, reuse.
- Returning the loss after the step. The loss must be computed DURING the forward pass that generated the gradients, which is what nnx.value_and_grad returns. Don’t recompute it after the update; that gives you the post-step loss, not the step’s loss.
- Naming layers both 'l1' and 'l11' would collide on startswith. For this problem only l1 and l2 exist, so it’s safe.
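If your layer names can collide, one option is to compare the whole first path entry instead of a string prefix. This is a sketch, not part of the original recipe: zero_exact is a hypothetical helper, the grads are a plain dict, and it assumes the first entry is a DictKey (.key) or GetAttrKey (.name).

```python
import jax
import jax.numpy as jnp

def zero_prefix(path, leaf):
    # The recipe's matcher: a string prefix test, so 'l11' also matches 'l1'.
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    return jnp.zeros_like(leaf) if name.startswith('l1') else leaf

def zero_exact(layer_name):
    """Build a callback that zeros leaves whose FIRST path entry is layer_name."""
    def fn(path, leaf):
        first = path[0]
        # DictKey exposes .key, GetAttrKey exposes .name; compare the whole name.
        key = getattr(first, 'key', getattr(first, 'name', None))
        return jnp.zeros_like(leaf) if key == layer_name else leaf
    return fn

# Hypothetical grads with colliding layer names.
grads = {'l1': {'kernel': jnp.ones((2, 2))}, 'l11': {'kernel': jnp.ones((2, 2))}}

by_prefix = jax.tree_util.tree_map_with_path(zero_prefix, grads)
by_exact = jax.tree_util.tree_map_with_path(zero_exact('l1'), grads)
```

by_prefix wrongly zeros l11 as well, while by_exact freezes only l1.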
Problem
Write freeze_first_dense(seed, x, y, lr):
- Build a TwoLayer(nnx.Module) with l1 = nnx.Linear(in_features, hidden) and l2 = nnx.Linear(hidden, out_features), with hidden = 4. Forward: l2(relu(l1(x))).
- Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Save frozen_kernel_norm_before = float(jnp.linalg.norm(model.l1.kernel.value)).
- Compute loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) with MSE loss.
- Use jax.tree_util.tree_map_with_path to zero every grad leaf whose path starts with 'l1'.
- Call optimizer.update(model, masked_grads).
- Save frozen_kernel_norm_after = float(jnp.linalg.norm(model.l1.kernel.value)).
- Return jnp.array([float(loss), frozen_kernel_norm_after, frozen_kernel_norm_before]).
The two norms MUST be equal. If they aren’t, the freeze leaked.
Inputs:
- seed: int (passed as float).
- x: 2-D (N, in_features).
- y: 2-D (N, out_features).
- lr: float.
Output: length-3 [loss, norm_after, norm_before] with norm_after == norm_before.