NNX Surgery: Add Layer
Why this matters
Layer expansion / warm-starting is one of the standard transfer learning moves: you trained a small model, you want a deeper one, and you don’t want to throw away the smaller checkpoint. So you initialize the deeper model and copy the smaller model’s weights into the first N layers — the new layers stay at their fresh random init.
Why not just train from scratch? Two reasons:
- Compute economics. A few hundred GPU-hours of backbone training carries over for free.
- Optimization geometry. A randomly initialized large model starts from a different point on the loss landscape than a partially pretrained one, and starting partially trained tends to converge faster on downstream data.
Stacking tricks like this show up in practice when scaling up training: train at small scale, expand, continue training.
In nnx the operation is just attribute assignment, layer by layer.
The recipe
small = MLP(in_features=D, num_layers=L_old, features=F, rngs=...)
big = MLP(in_features=D, num_layers=L_old + L_extra, features=F, rngs=...)
# Copy first L_old layers from small into big.
for i in range(L_old):
    big.layers[i].kernel.value = small.layers[i].kernel.value
    big.layers[i].bias.value = small.layers[i].bias.value
The remaining L_extra layers in big keep their fresh init.
There is no unfreeze step and no path-string lookup, as you might need with Linen's nested parameter dicts; just direct attribute writes.
Why use a different seed for the big model
If both models are built with the same seed, the first L_old layers of the big model already hold the same parameter values as the small model: both constructors draw the same sequence of keys for layers of the same shapes. The copy is then a no-op, and the test would pass by accident.
Use a different Rngs (e.g., nnx.Rngs(seed + 1000)) for the big model. Now the copy is observable: without it, the big model's first L_old layers would keep their own random init values; with it, they match the small model. The downstream layers (index L_old onward) keep the big model's fresh random init either way.
What about activations?
For a real warm-start you'd want the small model's function to be preserved at the start of training: big(x) ≈ small(x). A plain layer copy with random new layers doesn't give you that, because the random tail layers still scramble the output.
Common tricks to recover function preservation:
- Identity init: initialize each new layer as the identity (I for kernel, zero for bias).
- Skip-init: insert each new layer as a residual skip with a zero-initialized branch, so the block computes x + 0 = x.
Both are cheap to add once you know the surgery primitives. This problem just covers the param copy itself; the function-preservation tricks are good follow-ups for self-study.
API recap: nnx.List indexing
nnx.List (pos 43, 92) supports plain Python indexing. Both of these are valid:
big.layers[0].kernel.value = small.layers[0].kernel.value # write
print(big.layers[2].bias.value.shape) # read
len(big.layers) works too, and so does iteration with for layer in big.layers. It behaves like a regular Python list; the difference is that NNX registers it as a data pytree, so the layers it holds show up in the model's state.
Common pitfalls
- Same seed for both models. Then the first L_old layers match by accident; the copy isn’t observable.
- Off-by-one on the loop. for i in range(L_old): note L_old, not L_old - 1 or L_old + 1.
- Forgetting the bias. kernel is the matrix, bias is the offset. Both need copying for the layer to behave identically. A half-copy leaves the bias at whatever the big model's init produced (zeros under nnx.Linear's default bias init, but conceptually wrong).
- Cast to int. num_old_layers, num_extra_layers, and features arrive as floats. Cast them with int() before using them in range(...) or as dimensions.
- Plain Python list instead of nnx.List. Same trap as pos 92: the MLP layers won't show up in the state tree.
Problem
Write add_layer_warm_start(seed, x, num_old_layers, num_extra_layers, features):
- Define an MLP(nnx.Module) with an nnx.List of nnx.Linear layers. The first layer maps x.shape[-1] -> features; the rest map features -> features.
- Build small with num_old_layers layers using nnx.Rngs(int(seed)).
- Build big with num_old_layers + num_extra_layers layers using nnx.Rngs(int(seed) + 1000): a different seed, so the new params start different.
- For i in range(num_old_layers): copy kernel.value and bias.value from small.layers[i] to big.layers[i].
- Return big(x).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- num_old_layers, num_extra_layers, features: ints (passed as floats).
Output: 1-D array of shape (features,), the big model's output after the warm-start copy.