Pre-train then Fine-tune (Frozen Trunk)
Why this matters
Modern ML rarely trains from scratch. The dominant pattern is:
- Pre-train a big model on a broad task (massive corpus, lots of compute).
- Fine-tune on your downstream task (small dataset, minutes to hours of compute), usually with the trunk frozen and only the head training.
This works because pre-training learns general-purpose features that transfer; fine-tuning just adapts the output head to your label space. Examples in the wild:
- NLP: BERT pre-trained on masked-LM, fine-tuned with a classifier head.
- Vision: ViT pre-trained on ImageNet-21k, fine-tuned with a different classifier head.
- CLIP: pre-trained contrastively, frozen and used as a feature extractor.
Freezing the trunk has two big benefits:
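- Speed: the frozen parameters, often the vast majority, need no gradients, so backprop can be much cheaper. Zeroing updates in the optimizer does not by itself skip the gradient computation, though; the saving comes from not computing trunk gradients at all.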
- Stability: the trunk’s weights are known to be useful; letting them drift on a tiny fine-tune set risks catastrophic forgetting (loss of pretrained knowledge).
The classic counter-pattern is full fine-tuning (no freeze), which sometimes works better but needs more data and careful LR scheduling. Modern compromise: LoRA (only train low-rank adapters), which is covered later.
How to freeze in Optax
Two equally-valid ways:
Way A — optax.multi_transform (used in the reference solution):
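import jax
import optax

def path_to_str(path):
    # Helper assumed here (the text does not define it): render a key path as
    # a string such as "['params']['Dense_0']['kernel']".
    return jax.tree_util.keystr(path)

labels = jax.tree_util.tree_map_with_path(
    lambda path, _v: "frozen" if "Dense_0" in path_to_str(path) else "trainable",
    params,
)
tx = optax.multi_transform(
    {
        "trainable": optax.sgd(lr_fine),
        "frozen": optax.set_to_zero(),  # zero updates ⇒ params don't move
    },
    labels,
)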
set_to_zero() is a transformation that returns 0-magnitude updates
no matter what gradient you give it. Combined with a label, it
makes selected leaves immutable.
Way B — optax.masked:
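frozen = jax.tree_util.tree_map_with_path(
    lambda path, _v: "Dense_0" in path_to_str(path),  # True = frozen
    params,
)
tx = optax.chain(
    optax.sgd(lr_fine),                         # SGD update for every leaf
    optax.masked(optax.set_to_zero(), frozen),  # then zero the frozen leaves
)
optax.masked(t, mask) runs transformation t only on leaves where
mask == True. Leaves where mask == False are passed through unchanged,
so wrapping the optimizer alone, as in optax.masked(optax.sgd(lr_fine), mask),
would hand the raw gradients of the unmasked leaves to apply_updates
instead of freezing them. Here the mask therefore selects the frozen
leaves and wraps set_to_zero().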
Both work. multi_transform scales better (3+ groups, different
LRs); masked is simpler for the binary case.
Note on tree_map_with_path
The path arg is a tuple of key objects (DictKey, SequenceKey, etc.),
not a string. To test whether a path passes through a given dict key, use:
any(getattr(k, "key", None) == "Dense_0" for k in path)
This works whether the path entry is a DictKey (has .key) or
something else: for an entry without .key, getattr returns None, the
comparison is False, and that entry is skipped.
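To make the path objects less abstract, the following sketch walks a small nested dict and applies the check above. The parameter tree and the helper name in_module are hypothetical, introduced only for illustration; jax.tree_util.keystr is used to print each path in readable form.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameter dict, shaped like a two-layer MLP's params.
params = {
    "params": {
        "Dense_0": {"kernel": jnp.zeros((3, 8)), "bias": jnp.zeros((8,))},
        "Dense_1": {"kernel": jnp.zeros((8, 1)), "bias": jnp.zeros((1,))},
    }
}


def in_module(path, name):
    # DictKey entries expose .key; entries without it yield None and never match.
    return any(getattr(k, "key", None) == name for k in path)


# Boolean tree with the same structure as params: True under Dense_0.
is_trunk = jax.tree_util.tree_map_with_path(
    lambda path, _v: in_module(path, "Dense_0"), params
)

# Human-readable form of every leaf path, handy when debugging a label tree.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
readable = [jax.tree_util.keystr(path) for path, _leaf in leaves_with_paths]
for line in readable:
    print(line)
```

Matching on key objects rather than on a substring of the printed path avoids accidental hits, for example a layer named Dense_01 when searching for Dense_0.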
The recipe
- Build MLP = Dense(8) -> relu -> Dense(1).
- Init with PRNGKey(seed) and x_pre.
- Pre-train for 3 SGD steps at lr_pre on (x_pre, y_pre).
- Fine-tune for finetune_steps SGD steps at lr_fine on (x_fine, y_fine), freezing Dense_0 (the trunk).
- Return [final_finetune_loss] as a 1-D (1,) array.
Inputs:
- seed: float (cast to int).
- x_pre: 2-D (N_p, D) pre-train features.
- y_pre: 1-D (N_p,) pre-train targets.
- x_fine: 2-D (N_f, D) fine-tune features.
- y_fine: 1-D (N_f,) fine-tune targets.
- lr_pre, lr_fine: floats, the SGD learning rates.
- finetune_steps: float (cast to int), the fine-tune step count.
Output: 1-D (1,) — [loss_after_finetuning].