
NNX Surgery Freeze

Why this matters

“Freeze the first N layers” is the canonical move of transfer learning: load a pretrained backbone, freeze it, and train only the new head. Done correctly, the backbone’s parameters are bit-for-bit identical before and after the optimizer step. Done incorrectly, you quietly destroy the pretrained features.

There are two clean ways to freeze in nnx 0.12.6:

  1. Mask the gradients before optimizer.update. Compute grads normally with nnx.value_and_grad, then zero the entries corresponding to frozen params. SGD on a zero gradient is a no-op: θ ← θ - lr * 0 = θ.
  2. Use optax.multi_transform. Label parts of the param tree and route them to different sub-optimizers. The frozen branch gets optax.set_to_zero().

This problem walks through option (1). It’s slightly less composable than multi_transform but easier to reason about and requires no new optax machinery — just the same tree_map_with_path you saw in pos 14 (nnx-state-transform).

The recipe

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Assumes model, optimizer, loss_fn, x, y are set up as in the Problem section.

# 1. Standard forward + grads.
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)

# 2. Zero the entries you want frozen.
def zero_first_layer(path, leaf):
    # Flatten the key path into a dotted string that starts with the layer name.
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    if name.startswith('l1'):
        return jnp.zeros_like(leaf)
    return leaf

masked_grads = jax.tree_util.tree_map_with_path(zero_first_layer, grads)

# 3. Step. Frozen params receive a zero gradient and don't change.
optimizer.update(model, masked_grads)
```

The semantics: the optimizer applies its update rule to every leaf in masked_grads. With optax.sgd, the update rule is θ ← θ - lr * g. For frozen leaves, g = 0, so θ ← θ. The parameter is unchanged exactly, not just approximately, because adding a zero update is exact in floating-point arithmetic.

Why grad masking, not param freezing

A naive alternative: step with the unmasked gradients, then OVERWRITE the frozen params with their pre-step values:

```python
pre = jnp.copy(model.l1.kernel.value)  # snapshot (l1's bias would need the same)
optimizer.update(model, grads)         # step with the UNMASKED grads
model.l1.kernel.value = pre            # restore
```

This works for SGD without momentum, but with any optimizer that keeps internal state it leaves that state inconsistent with the parameters. Adam, for example, accumulates m and v statistics for every parameter. Even if you reset the param, the optimizer’s internal stats now reflect a step that “happened”, and later steps build on those skewed stats (for instance, the first steps after you unfreeze the layer). Grad masking sidesteps the issue, and set_to_zero-style behavior is exactly what optax’s multi_transform formalizes.

API: jax.tree_util.tree_map_with_path

This is the same path-based transform you saw in nnx-state-transform (pos 14). The path is a tuple of JAX key entries (such as DictKey or GetAttrKey), not plain strings; the '.'.join(str(p.key) ...) pattern flattens it into a dotted string for prefix matching.

For a 2-layer model with attribute names l1 and l2, the path entries, after the str(p.key) / str(p) conversion, look like:

  • ('l1', 'kernel', '.value') — first layer’s kernel
  • ('l1', 'bias', '.value') — first layer’s bias
  • ('l2', 'kernel', '.value') — second layer’s kernel
  • ('l2', 'bias', '.value') — second layer’s bias

name.startswith('l1') matches the first two and zeros their gradients.
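To see the path strings the join produces, here is a sketch on a plain nested dict, an assumption standing in for the NNX grads State. On a dict every path entry is a DictKey, so there is no trailing '.value' component; in the NNX case that part presumably comes from how the variable wrapper flattens. The helper path_name is illustrative.

```python
import jax
import jax.numpy as jnp

# Plain dict standing in for the grads pytree (illustrative).
grads = {
    "l1": {"kernel": jnp.ones((3, 4)), "bias": jnp.ones((4,))},
    "l2": {"kernel": jnp.ones((4, 2)), "bias": jnp.ones((2,))},
}


def path_name(path):
    # Same flattening as the recipe: entries with .key use it, others use str(entry).
    return ".".join(str(p.key) if hasattr(p, "key") else str(p) for p in path)


leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads)
names = [path_name(path) for path, _ in leaves_with_paths]
print(names)  # ['l1.bias', 'l1.kernel', 'l2.bias', 'l2.kernel']  (dict keys flatten sorted)


def zero_first_layer(path, leaf):
    return jnp.zeros_like(leaf) if path_name(path).startswith("l1") else leaf


masked = jax.tree_util.tree_map_with_path(zero_first_layer, grads)
```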

Verifying the freeze worked

Record the kernel norm before the step and again after it. The two must be exactly equal, not just close. If they differ, a non-zero gradient is still reaching the frozen layer somewhere.

```python
before = float(jnp.linalg.norm(model.l1.kernel.value))
optimizer.update(model, masked_grads)
after = float(jnp.linalg.norm(model.l1.kernel.value))
assert before == after          # bit-for-bit
```

The problem’s function returns [loss, after, before] so the grading harness can check after == before.
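Equal norms are necessary but not sufficient: a norm-preserving change, such as reordering a row's entries, would pass the norm test. A stricter check compares every frozen leaf element-wise, bias included. The sketch below shows this on plain dicts with one hand-written SGD step; the names are illustrative. With an NNX model the same comparison would run against a snapshot of the layer's parameters taken before the step, and how you take that snapshot is left as an assumption here.

```python
import jax
import jax.numpy as jnp

# Illustrative params and already-masked grads (l1 zeroed, l2 not).
params = {
    "l1": {"kernel": jnp.arange(6.0).reshape(2, 3), "bias": jnp.ones(3)},
    "l2": {"kernel": jnp.ones((3, 1)), "bias": jnp.zeros(1)},
}
masked_grads = {
    "l1": jax.tree_util.tree_map(jnp.zeros_like, params["l1"]),
    "l2": jax.tree_util.tree_map(jnp.ones_like, params["l2"]),
}

# One plain SGD step written out by hand: theta <- theta - lr * g.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, masked_grads)

# Strict check: every frozen leaf (kernel AND bias) must be element-wise identical.
l1_frozen = jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda a, b: bool(jnp.array_equal(a, b)), params["l1"], new_params["l1"]
    )
)

# Why a norm alone is weaker: reversing each row keeps the norm but changes the array.
k = params["l1"]["kernel"]
same_norm = bool(jnp.linalg.norm(k) == jnp.linalg.norm(k[:, ::-1]))
same_array = bool(jnp.array_equal(k, k[:, ::-1]))
print(l1_frozen, same_norm, same_array)  # True True False
```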

Common pitfalls

  • Forgetting to zero the BIAS gradients too. If only kernel gradients are masked, the bias still trains. Use name.startswith('l1') so all of l1’s leaves are caught.
  • Building the optimizer inside the step. Optimizers have state (e.g., Adam’s m/v). Build once, reuse.
  • Recomputing the loss after the step. The reported loss must come from the same forward pass that produced the gradients, which is what nnx.value_and_grad returns. Recomputing it after the update gives you the post-step loss, not the step’s loss.
  • Prefix collisions. If a model has both l1 and l11, name.startswith('l1') matches both. For this problem only l1 and l2 exist, so the prefix test is safe; in general, compare the first path component exactly.
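One way to avoid the prefix collision is to compare the first path component exactly instead of the joined string. A minimal sketch on a plain dict with hypothetical layer names l1, l11 and l2; the first_key helper is illustrative and assumes the top-level entry exposes .key (dict-style) or .name (attribute-style), falling back to str.

```python
import jax
import jax.numpy as jnp

# Hypothetical grads with a colliding name (l11).
grads = {"l1": jnp.ones(2), "l11": jnp.ones(2), "l2": jnp.ones(2)}


def first_key(path):
    """Name of the top-level entry in a key path (hypothetical helper)."""
    entry = path[0]
    return str(getattr(entry, "key", getattr(entry, "name", entry)))


def path_name(path):
    return ".".join(str(p.key) if hasattr(p, "key") else str(p) for p in path)


# Prefix test: also zeros l11 by accident.
by_prefix = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g) if path_name(path).startswith("l1") else g, grads
)
# Exact test on the first component: zeros only l1.
by_exact = jax.tree_util.tree_map_with_path(
    lambda path, g: jnp.zeros_like(g) if first_key(path) == "l1" else g, grads
)

print(float(by_prefix["l11"].sum()), float(by_exact["l11"].sum()))  # 0.0 2.0
```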

Problem

Write freeze_first_dense(seed, x, y, lr):

  1. Build a TwoLayer(nnx.Module) with l1 = nnx.Linear(in, hidden) and l2 = nnx.Linear(hidden, out). hidden = 4. Forward: l2(relu(l1(x))).
  2. Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  3. Save frozen_kernel_norm_before = float(jnp.linalg.norm(model.l1.kernel.value)).
  4. Compute loss, grads = nnx.value_and_grad(loss_fn)(model, x, y) with MSE loss.
  5. Use jax.tree_util.tree_map_with_path to zero every grad leaf whose path starts with 'l1'.
  6. optimizer.update(model, masked_grads).
  7. Save frozen_kernel_norm_after = float(jnp.linalg.norm(model.l1.kernel.value)).
  8. Return jnp.array([float(loss), frozen_kernel_norm_after, frozen_kernel_norm_before]).

The two norms MUST be equal. If they aren’t, the freeze leaked.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (N, in_features).
  • y: 2-D (N, out_features).
  • lr: float.

Output: length-3 [loss, norm_after, norm_before] with norm_after == norm_before.
