Pre-train then Fine-tune (Frozen Trunk)

Why this matters

Modern ML rarely trains from scratch. The dominant pattern is:

  1. Pre-train a big model on a broad task (massive corpus, lots of compute).
  2. Fine-tune on your downstream task (small dataset, minutes to hours of compute), usually with the trunk frozen and only the head training.

This works because pre-training learns general-purpose features that transfer; fine-tuning just adapts the output head to your label space. Examples in the wild:

  • NLP: BERT pre-trained on masked-LM, fine-tuned with a classifier head.
  • Vision: ViT pre-trained on ImageNet-21k, fine-tuned with a different classifier head.
  • CLIP: pre-trained contrastively, frozen and used as a feature extractor.

Freezing the trunk has two big benefits:

  • Speed: the frozen trunk needs no parameter gradients, and it usually holds the large majority of the params, so backprop is much cheaper.
  • Stability: the trunk’s weights are known to be useful; letting them drift on a tiny fine-tune set risks catastrophic forgetting (loss of pretrained knowledge).

The classic counter-pattern is full fine-tuning (no freeze), which sometimes works better but needs more data and careful LR scheduling. Modern compromise: LoRA (only train low-rank adapters), which is covered later.

How to freeze in Optax

Two equally valid ways:

Way A — optax.multi_transform (used in the reference solution):

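import jax
import optax

# Label every leaf: anything under Dense_0 is frozen, everything else trains.
# jax.tree_util.keystr(path) renders the key path as a string containing each key name.
labels = jax.tree_util.tree_map_with_path(
    lambda path, _v: "frozen" if "Dense_0" in jax.tree_util.keystr(path) else "trainable",
    params,
)
tx = optax.multi_transform(
    {
        "trainable": optax.sgd(lr_fine),
        "frozen":    optax.set_to_zero(),    # zero updates ⇒ params don't move
    },
    labels,
)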

set_to_zero() is a transformation that returns all-zero updates no matter what gradient you give it. Assigned to a label, it keeps every leaf with that label fixed at its current value.

Way B — optax.masked:

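frozen = jax.tree_util.tree_map_with_path(
    lambda path, _v: "Dense_0" in jax.tree_util.keystr(path),  # True = frozen
    params,
)
tx = optax.chain(
    optax.sgd(lr_fine),
    optax.masked(optax.set_to_zero(), frozen),  # zero the trunk's updates
)

optax.masked(t, mask) runs transformation t only on leaves where the mask is True; leaves where it is False pass through unchanged. That is why optax.masked(optax.sgd(lr_fine), trainable_mask) on its own does not freeze anything: the unmasked leaves would receive the raw gradient as their update. Masking set_to_zero() onto the frozen leaves after SGD gives them zero updates.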

Both work. multi_transform scales better (3+ groups, different LRs); masked is simpler for the binary case.

Note on tree_map_with_path

The path arg is a tuple of key objects (DictKey, SequenceKey, etc.), not a string. To check whether a leaf sits under a given module without building a string first, use:

any(getattr(k, "key", None) == "Dense_0" for k in path)

This works whether the path entry is a DictKey (which has .key) or a key type without that attribute: getattr then returns None, the comparison is False, and any() moves on to the next entry.
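The check above can be wrapped in a small helper. This is a sketch under assumptions: path_has_key is a hypothetical name, and the nested params dict only mimics the layout Flax typically produces.

```python
import jax
import jax.numpy as jnp

def path_has_key(path, name):
    """Return True if any entry of the key path is a dict key equal to `name`."""
    return any(getattr(k, "key", None) == name for k in path)

# Toy tree mimicking a Flax variables dict (illustrative only).
params = {
    "params": {
        "Dense_0": {"kernel": jnp.zeros((2, 3))},
        "Dense_1": {"kernel": jnp.zeros((3, 1))},
    }
}

# Boolean tree with the same structure as params: True under Dense_0.
is_trunk = jax.tree_util.tree_map_with_path(
    lambda path, _v: path_has_key(path, "Dense_0"),
    params,
)

# For debugging, keystr renders each leaf's key path as a readable string.
leaf_paths = [
    jax.tree_util.keystr(path)
    for path, _leaf in jax.tree_util.tree_flatten_with_path(params)[0]
]
```

Comparing keys for equality is stricter than a substring test on the rendered path: a substring match for "Dense_1" would also hit a layer named "Dense_10".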

The recipe

  1. Build MLP = Dense(8) -> relu -> Dense(1).
  2. Init with PRNGKey(seed) and x_pre.
  3. Pre-train for 3 SGD steps at lr_pre on (x_pre, y_pre).
  4. Fine-tune for finetune_steps SGD steps at lr_fine on (x_fine, y_fine), freezing Dense_0 (the trunk).
  5. Return [final_finetune_loss] as a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x_pre: 2-D (N_p, D) pre-train features.
  • y_pre: 1-D (N_p,) pre-train targets.
  • x_fine: 2-D (N_f, D) fine-tune features.
  • y_fine: 1-D (N_f,) fine-tune targets.
  • lr_pre, lr_fine: floats — SGD learning rates.
  • finetune_steps: float (cast to int) — fine-tune step count.

Output: 1-D array of shape (1,): [loss_after_finetuning].
