
Partial Init — Warm-Start From Smaller Checkpoint

Why this matters

“Partial init” is a standard way to expand a pre-trained model: add a layer (or a few) to a model you’ve already trained, and warm-start the new, bigger model from the smaller checkpoint. The first N layers come from the trained model; the new tail layers are freshly initialized and learn from scratch.

Real cases:

  • Stacking pretraining: train an L-layer model, then deepen it to L+k layers and continue pretraining. Progressive-stacking schemes for BERT grow depth this way, although they usually fill the new layers by copying trained ones rather than initializing them fresh.
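  • Bolt-on heads: load a backbone checkpoint, then add task-specific output projections that start from fresh init.
  • Architecture surgery: splice a new block between layers i and i+1 of a trained network without disturbing the rest. The pre-existing blocks come from disk; the new block gets fresh init. If the new block is another auto-named nn.Dense, every Dense after the splice point shifts to a new name (Dense_{i+1} becomes Dense_{i+2}, and so on), so those blocks must be copied by position rather than by matching key.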

Flax makes this straightforward because params are nested dicts of arrays (plain dicts in recent Flax releases, FrozenDict in older ones). To partial-init: init the BIG model normally to get a full-shape param tree, then walk the SMALL model’s param tree and overwrite the matching keys in the big tree. Anything not overwritten keeps its fresh init.

The pattern: init both, copy what overlaps

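import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):            # assumed definition; the relu activation is a guess
    num_layers: int
    features: int = 8
    @nn.compact
    def __call__(self, x):
        for _ in range(self.num_layers):
            x = nn.relu(nn.Dense(self.features)(x))  # auto-named Dense_0, Dense_1, ...
        return x

N, k = 2, 2
small = MLP(num_layers=N)        # already-trained shape (warm)
big   = MLP(num_layers=N + k)    # target shape (cold)
x = jnp.ones((4, 3))             # (batch, in_dim); only its shape matters for init
rng_s, rng_b = jax.random.split(jax.random.PRNGKey(0))
small_params = small.init(rng_s, x)        # warm starting point
big_params   = big.init(rng_b, x)          # full target shape

# Partial init: overlay small_params onto big_params.
new_inner = dict(big_params["params"])     # shallow copy; big_params is untouched
for i in range(N):
    new_inner[f"Dense_{i}"] = small_params["params"][f"Dense_{i}"]
new_params = {"params": new_inner}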

The bigger model’s Dense_0 ... Dense_{N-1} come from the small model. Dense_N ... Dense_{N+k-1} stay at their fresh init.

In production, small_params would be loaded from a checkpoint file (flax.serialization.from_bytes) rather than re-initialized. The mechanism is identical — copy keys from one params dict into another.

Param naming convention

Flax auto-names submodules created inside an @nn.compact method by class name plus a per-class counter starting at 0: the i-th nn.Dense (counting from 0) becomes Dense_i. So:

  • Small model with N=2: keys are Dense_0, Dense_1.
  • Big model with N=4: keys are Dense_0, Dense_1, Dense_2, Dense_3.

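Dense_0 and Dense_1 exist in both; that’s the overlap. Their shapes must also match (same features dimension, same input dim). The overlay itself never compares shapes, so a mismatch is not caught when you copy. It shows up later, typically as a shape error when big.apply(new_params, x) runs.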
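Because the copy loop never compares shapes, a cheap guard is to check them before overwriting. The sketch below uses a hypothetical helper, overlay_checked (not part of Flax), on plain nested dicts like those recent Flax returns from init; the toy arrays stand in for real params.

```python
import jax
import jax.numpy as jnp


def overlay_checked(big_inner, small_inner, keys):
    """Copy small_inner[key] over a shallow copy of big_inner for each key,
    raising ValueError if any leaf shape differs (hypothetical helper)."""
    out = dict(big_inner)  # shallow copy: big_inner itself is never modified
    for key in keys:
        big_shapes = jax.tree_util.tree_map(lambda a: a.shape, big_inner[key])
        small_shapes = jax.tree_util.tree_map(lambda a: a.shape, small_inner[key])
        if big_shapes != small_shapes:
            raise ValueError(f"{key}: big {big_shapes} vs small {small_shapes}")
        out[key] = small_inner[key]
    return out


# Toy trees standing in for big_params["params"] and small_params["params"].
big_inner = {
    "Dense_0": {"kernel": jnp.zeros((3, 8)), "bias": jnp.zeros((8,))},
    "Dense_1": {"kernel": jnp.zeros((8, 8)), "bias": jnp.zeros((8,))},
}
small_ok = {"Dense_0": {"kernel": jnp.ones((3, 8)), "bias": jnp.ones((8,))}}
small_wide = {"Dense_0": {"kernel": jnp.ones((3, 16)), "bias": jnp.ones((16,))}}

merged = overlay_checked(big_inner, small_ok, ["Dense_0"])  # shapes agree: copied
try:
    overlay_checked(big_inner, small_wide, ["Dense_0"])     # features 16 vs 8: rejected
except ValueError as err:
    print(err)
```

Failing at copy time points at the offending key, which is usually easier to debug than an error raised deep inside a forward pass.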

Counting params: leaves and .size

A clean way to verify the result has the right total parameter count:

leaves = jax.tree_util.tree_leaves(new_params)
total = sum(int(l.size) for l in leaves)

tree_leaves flattens the pytree to a list of arrays. Each array’s .size is its number of scalar elements. The sum over all leaves is the total parameter count.

For a single nn.Dense(features=F) applied to a (batch, D) input: kernel is (D, F) (D·F params) and bias is (F,) (F params), total D·F + F = F·(D+1).

Why init the big model FIRST?

Two reasons:

  1. Get the right shape: big.init(rng_b, x) is the simplest way to get the shapes for Dense_N, ..., Dense_{N+k-1} without re-deriving them by hand.
  2. Init the new layers naturally: nn.Dense defaults to LeCun-normal for kernels and zeros for biases. By using init, you get those defaults for free on the new layers.

Building the params for Dense_N onward by hand is error-prone; let Flax do it.

Common pitfalls

  • Mismatched shapes between small and big: in_dim and features must match for every copied layer. If you widened the layers (larger features), partial init becomes “partial partial”: only a sub-block of each kernel overlaps, and you would have to copy array slices rather than whole arrays. This problem keeps features constant.
  • Mutating the original dict: new_inner = dict(big_params['params']) makes a shallow copy of the params mapping, so reassigning its keys leaves big_params untouched. Dict unpacking ({**big_params['params']}) works too.
  • Forgetting the params outer key: in Flax, params live under {"params": {...}}. The outer level is the collection name; everything in this problem stays in the params collection.
  • Using the same rng for both inits: not strictly wrong, but conceptually the small and big inits represent different checkpoints — splitting the rng makes the intent clear.

Problem

Build two MLPs:

  • small = MLP(num_layers=num_full_layers).
  • big = MLP(num_layers=num_full_layers + num_extra_layers).

Both use nn.Dense(features) per layer (same features for every layer in both models).

Init both with split rngs. Build a new_params whose first num_full_layers Dense blocks come from small_params and whose remaining blocks stay at the big model’s fresh init. Return:

[float(total_param_count_of_new_params), float(num_full_layers + num_extra_layers)]

The first number is the leaf-summed total parameter count; the second is the number of layers in the resulting big model.

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (batch, in_dim).
  • num_full_layers: float (cast to int) — layers in the small (warm) model.
  • num_extra_layers: float (cast to int) — additional layers in the big model.
  • features: float (cast to int) — Dense output dimension (constant across layers).

Output: 1-D (2,).

Hints

flax params checkpoint
