
Train Step with Frozen Params (Mask)

Why this matters

Fine-tuning large models often requires freezing some parameters (e.g. a transformer body) while updating others (e.g. a task-specific head). The simplest way to achieve this in JAX/Optax is to zero out the gradients at frozen positions before calling the optimizer. With plain SGD, a zero gradient yields a zero update, so those parameters stay exactly as they were.

The recipe

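import jax.numpy as jnp
import optax

def train_step(params, frozen_mask, grads, lr):  # function name assumed
    # Zero the gradients at frozen positions (frozen_mask == 1.0).
    masked_grads = jnp.where(frozen_mask >= 0.5, 0.0, grads)
    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(params)
    updates, _ = optimizer.update(masked_grads, opt_state, params)
    return optax.apply_updates(params, updates)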

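A value of 1.0 in frozen_mask marks a frozen position. The recipe builds a fresh optimizer and state on every call, which is harmless for stateless SGD. A stateful optimizer (e.g. momentum or Adam) needs its opt_state initialized once and carried between steps. For per-module freezing over a parameter pytree, see optax.masked.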

Common pitfalls

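  • Apply the mask to the gradients before optimizer.update, not to the parameters after apply_updates. Running the same jnp.where on the updated params would overwrite frozen parameters with 0.0 instead of preserving their values.
  • jnp.where(cond, 0.0, grads) returns 0.0 where cond is True (frozen_mask is 1.0) and the original gradient elsewhere. The >= 0.5 threshold tolerates masks stored as floats.
  • In production, optax.masked applies a transform only to the pytree leaves selected by a boolean mask, and optax.multi_transform assigns different transforms to different parts of a pytree (e.g. different layers). Both operate on whole leaves, not on individual elements of an array.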

Inputs

  • params: 1-D JAX array of model parameters.
  • frozen_mask: 1-D JAX array of 0.0/1.0 values, same shape as params; 1.0 marks a frozen position.
  • grads: 1-D JAX array of gradients, same shape as params.
  • lr: scalar SGD learning rate.

Output

1-D array โ€” params after one SGD step with frozen positions unchanged.

