hard
primitives
Train Step with Frozen Params (Mask)
Why this matters
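Fine-tuning large models often requires freezing some parameters (e.g. a transformer body) while updating others (e.g. a task-specific head). The simplest way to achieve this in JAX/Optax is to zero out gradients at frozen positions before calling the optimizer. With plain SGD a zero gradient produces a zero update, so frozen values stay fixed; optimizers that add gradient-independent terms, such as weight decay, can still move them.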
The recipe
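import jax.numpy as jnp
import optax

def train_step(params, frozen_mask, grads, lr):  # illustrative function name
    # Zero the gradient wherever frozen_mask is 1.0 (a frozen position).
    masked_grads = jnp.where(frozen_mask >= 0.5, 0.0, grads)
    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(params)
    updates, _ = optimizer.update(masked_grads, opt_state, params)
    return optax.apply_updates(params, updates)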
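A value of 1.0 in frozen_mask marks a frozen position; comparing against 0.5 instead of testing exact equality treats any value at or above 0.5 as frozen. For more complex per-module freezing, see optax.masked and optax.multi_transform.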
Common pitfalls
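- Apply the mask to the gradients before optimizer.update, not to the parameters after apply_updates. Applying the same jnp.where to the updated parameters would overwrite the frozen values with 0.0 instead of preserving them.
- jnp.where(cond, 0.0, grads) returns 0.0 where cond is True (frozen_mask = 1.0) and the original gradient elsewhere, so the condition must select the frozen positions, not the trainable ones.
- In production, optax.multi_transform lets you assign different transforms to different parts of a pytree (e.g. different layers, with optax.set_to_zero() for the frozen ones), while optax.masked applies a single transform only where a mask is True and passes the remaining updates through unchanged.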
Inputs
- params: 1-D JAX array of model parameters.
- frozen_mask: 1-D JAX array of 0.0/1.0 values; 1.0 marks a frozen position.
- grads: 1-D JAX array of gradients.
- lr: scalar SGD learning rate.
Output
1-D array โ params after one SGD step with frozen positions unchanged.
Hints
optax
training
frozen