Partial Init — Warm-Start From Smaller Checkpoint
Why this matters
“Partial init” is the canonical way to expand a pre-trained
model: add a layer (or a few) to a model you’ve already trained,
and warm-start the new bigger model from the smaller checkpoint.
The first N layers come from the trained model; the new tail
layers are freshly initialized and learn from scratch.
Real cases:
- Stacking pretraining: train an L-layer model, then deepen to L+k layers and continue pretraining (used in early BERT scaling experiments and in modular curricula like ProgressiveBERT).
- Bolt-on heads: load a backbone checkpoint, then add task-specific output projections that init fresh.
- Architecture surgery: splice a new block between layers i and i+1 of a trained network without disturbing the rest. The pre-existing blocks come from disk; the new block gets fresh init.
Flax makes this straightforward because params are nested dicts of arrays (plain dicts in recent Flax releases, FrozenDict in older ones). To partial-init, first init the BIG model normally to get a full-shape param tree. Then walk the SMALL model’s param tree and overwrite the matching keys in the big tree. Anything not overwritten stays at its fresh init.
The pattern: init both, copy what overlaps
import jax

# Assumes: MLP is a Flax module stacking num_layers nn.Dense layers,
# x is a sample input batch, rng is a PRNG key, and N, k are ints.
small = MLP(num_layers=N)      # already-trained shape (warm)
big = MLP(num_layers=N + k)    # target shape (cold)
rng_s, rng_b = jax.random.split(rng)
small_params = small.init(rng_s, x)  # warm starting point
big_params = big.init(rng_b, x)      # full target shape
# Partial init: overlay small_params onto big_params.
new_inner = dict(big_params["params"])  # shallow copy; big_params is not mutated
for i in range(N):
    key = f"Dense_{i}"
    new_inner[key] = small_params["params"][key]
new_params = {"params": new_inner}
The bigger model’s Dense_0 ... Dense_{N-1} come from the small
model. Dense_N ... Dense_{N+k-1} stay at their fresh init.
In production, small_params would be loaded from a checkpoint
file (flax.serialization.from_bytes) rather than re-initialized.
The mechanism is identical — copy keys from one params dict into
another.
Param naming convention
Flax auto-names submodules in a nn.compact module by class name
plus a counter: the i-th nn.Dense becomes Dense_i. So:
- Small model with N=2: keys are Dense_0, Dense_1.
- Big model with N=4: keys are Dense_0, Dense_1, Dense_2, Dense_3.
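Dense_0 and Dense_1 exist in both, and that is the overlap. Their shapes must also match (same features dimension, same input dimension). If they don't, the dict overwrite itself raises no error. You end up with a param tree whose shapes no longer fit the big model, and the problem only surfaces later, when those params are used.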
Counting params: leaves and .size
A clean way to verify the result has the right total parameter count:
leaves = jax.tree_util.tree_leaves(new_params)
total = sum(int(l.size) for l in leaves)
tree_leaves flattens the pytree to a list of arrays. Each
array’s .size is its number of scalar elements. The sum over
all leaves is the total parameter count.
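As a toy illustration, the leaf count can be applied to a hand-built tree shaped like one Dense(features=4) layer over a 3-dimensional input. These shapes are assumptions chosen to keep the arithmetic easy. The tree has two leaves and 3·4 + 4 = 16 scalars.

```python
import jax
import jax.numpy as jnp

# Hand-built params tree mimicking Flax's layout for one Dense layer.
params = {"params": {"Dense_0": {"kernel": jnp.zeros((3, 4)),
                                 "bias": jnp.zeros((4,))}}}

leaves = jax.tree_util.tree_leaves(params)  # [bias, kernel] arrays
total = sum(int(l.size) for l in leaves)    # 4 + 12
print(len(leaves), total)  # 2 16
```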
For a single nn.Dense(features=F) applied to a (batch, D) input, the kernel is (D, F) (D·F params) and the bias is (F,) (F params). The total is D·F + F = F·(D+1).
Why init the big model FIRST?
Two reasons:
- Get the right shape: big.init(rng_b, x) is the simplest way to know the shapes for Dense_N, ..., Dense_{N+k-1} without re-deriving them by hand.
- Init the new layers naturally: Flax’s default nn.Dense init is Lecun-normal for kernels and zeros for biases. By calling init, you get those defaults for free on the new layers.
Skipping the big init and building Dense_N by hand is fragile, because you must re-derive every shape and replicate the default initializers yourself. Let Flax do it.
Common pitfalls
- Mismatched shapes between small and big: features must match. If you grew features, partial init becomes “partial partial”, meaning only sub-blocks of each kernel overlap. This problem keeps features constant.
- Mutating the original dict: new_inner = dict(big_params['params']) makes a shallow copy of the layer-name mapping, so we can reassign its keys without touching big_params. Spread ({**...}) works too.
- Forgetting the params outer key: in Flax, params live under {"params": {...}}. The outer level is the collection name, and everything in this problem stays in the params collection.
- Using the same rng for both inits: this is not strictly wrong, but conceptually the small and big inits represent different checkpoints, and splitting the rng makes that intent clear.
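The shallow-copy point can be shown in miniature, using toy dicts of arrays in place of real Flax params (the layout is an assumption). Reassigning a key in the copy leaves big_params untouched. Entries that are not replaced are shared rather than duplicated, which is harmless because JAX arrays are immutable.

```python
import jax.numpy as jnp

# Toy stand-ins for the two params trees (layout is illustrative only).
big_params = {"params": {"Dense_0": {"bias": jnp.zeros(4)},
                         "Dense_1": {"bias": jnp.zeros(4)}}}
small_params = {"params": {"Dense_0": {"bias": jnp.ones(4)}}}

new_inner = dict(big_params["params"])  # shallow copy of the layer mapping
new_inner["Dense_0"] = small_params["params"]["Dense_0"]

# Equivalent spread form.
new_inner_spread = {**big_params["params"],
                    "Dense_0": small_params["params"]["Dense_0"]}

print(float(big_params["params"]["Dense_0"]["bias"][0]))  # 0.0 (original untouched)
print(float(new_inner["Dense_0"]["bias"][0]))             # 1.0
print(new_inner["Dense_1"] is big_params["params"]["Dense_1"])  # True (shared)
```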
Problem
Build two MLPs:
- small = MLP(num_layers=num_full_layers).
- big = MLP(num_layers=num_full_layers + num_extra_layers).
Both use nn.Dense(features) per layer (same features for
every layer in both models).
Init both with split rngs. Build a new_params whose first
num_full_layers Dense blocks come from small_params and whose
remaining blocks stay at the big model’s fresh init. Return:
[float(total_param_count_of_new_params), float(num_full_layers + num_extra_layers)]
The first number is the leaf-summed total parameter count; the second is the number of layers in the resulting big model.
Inputs:
- seed: float (cast to int); the PRNG seed.
- x: 2-D (N, in_dim).
- num_full_layers: float (cast to int); layers in the small (warm) model.
- num_extra_layers: float (cast to int); additional layers in the big model.
- features: float (cast to int); Dense output dimension (constant across layers).
Output: 1-D (2,).