medium primitives

NNX Surgery Add Layer

Why this matters

Layer expansion (warm-starting) is one of the standard transfer-learning moves: you trained a small model, you want a deeper one, and you don’t want to throw away the smaller checkpoint. So you initialize the deeper model and copy the smaller model’s weights into its first N layers, where N is the depth of the small model; the new layers stay at their fresh random init.

Why not just train from scratch? Two reasons:

  1. Compute economics. A few hundred GPU-hours of backbone training carries over for free.
  2. Optimization geometry. A randomly initialized large model starts from a different point in the loss landscape than a partially pretrained one; starting from partially trained weights tends to converge faster on downstream data.

Stacking tricks like this are one common way to grow a model in practice: train at small scale, expand, continue training.

In nnx the operation is just attribute assignment, layer by layer.

The recipe

small = MLP(in_features=D, num_layers=L_old, features=F, rngs=nnx.Rngs(seed))
big = MLP(in_features=D, num_layers=L_old + L_extra, features=F, rngs=nnx.Rngs(seed + 1000))

# Copy first L_old layers from small into big.
for i in range(L_old):
    big.layers[i].kernel.value = small.layers[i].kernel.value
    big.layers[i].bias.value = small.layers[i].bias.value

The remaining L_extra layers in big retain their fresh init. No unfreeze, no path strings — just direct attribute writes.

Why use a different seed for the big model

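If both models are built with the same seed, the first L_old layers of the big model already hold the same parameter values as the small model, because the same key sequence initializes same-shaped layers in the same order. The copy is then a no-op, and a test that compares those layers would pass by accident.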

Use a different Rngs (e.g., nnx.Rngs(seed + 1000)) for the big model. Now the copy is observable: without it, the big model’s first L_old layers would have random init values; with it, they match the small model. The downstream layers (L_old onward) keep the big model’s fresh random init either way.

What about activations?

For a real warm-start you’d want the function of the small model to be preserved at the start of training: big(x) ≈ small(x). With a plain layer-copy and random new layers, that’s not quite true — the random tail layers still scramble the output.

Common tricks to recover function-preservation:

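  • Initialize new layers as identity (identity matrix for the kernel, zeros for the bias). This needs square features -> features layers, and any activation wrapped around the new layer must leave its incoming values unchanged.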
  • Skip-init: insert each new layer as a residual skip with a zero-initialized branch — x + 0 = x.

Both are cheap to add once you know the surgery primitives. This problem just covers the param copy itself; the function-preservation tricks are good follow-ups for self-study.

API recap: nnx.List indexing

nnx.List (pos 43, 92) supports plain Python indexing. Both of these are valid:

big.layers[0].kernel.value = small.layers[0].kernel.value   # write
print(big.layers[2].bias.value.shape)                       # read

len(big.layers) works too, and so does iteration with for layer in big.layers. It behaves like a Python list, with the difference that NNX tracks its contents as part of the module’s state.

Common pitfalls

  • Same seed for both models. Then the first L_old layers match by accident; the copy isn’t observable.
  • Off-by-one on the loop. for i in range(L_old) — note L_old, not L_old - 1 or L_old + 1.
  • Forgetting the bias. kernel is the matrix, bias is the offset. Both need copying for the layer to behave identically. After a kernel-only copy, the bias is whatever the big model’s init produced. nnx.Linear initializes biases to zero by default, so the output may happen to match, but the copy is still incomplete.
  • Cast to int. num_old_layers, num_extra_layers, features arrive as floats. Cast them with int() before using them in range(...) or as dimensions.
  • Plain Python list, no nnx.List. Same trap as pos 92 — the MLP layers won’t show up in the state tree.
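The int-cast and loop-bound pitfalls can be seen without any model at all. This is a plain-Python sketch; layer_counts is a hypothetical helper, and the float values stand in for arguments that arrive as floats.

```python
def layer_counts(num_old_layers: float, num_extra_layers: float, features: float):
    """Hypothetical helper: cast float-typed arguments before using them as sizes."""
    n_old = int(num_old_layers)
    n_big = n_old + int(num_extra_layers)
    width = int(features)
    return n_old, n_big, width


# range() rejects floats, so the cast has to happen first.
try:
    range(2.0)
    float_range_ok = True
except TypeError:
    float_range_ok = False

n_old, n_big, width = layer_counts(2.0, 1.0, 8.0)

# range(n_old) visits exactly the indices of the old layers: 0 .. n_old - 1.
copy_indices = list(range(n_old))
new_indices = list(range(n_old, n_big))
```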

Problem

Write add_layer_warm_start(seed, x, num_old_layers, num_extra_layers, features):

  1. Define an MLP(nnx.Module) with nnx.List of nnx.Linear layers. First layer maps x.shape[-1] -> features; the rest map features -> features.
  2. Build small with num_old_layers layers using nnx.Rngs(int(seed)).
  3. Build big with num_old_layers + num_extra_layers layers using nnx.Rngs(int(seed) + 1000) — different seed, so the new params start different.
  4. For i in range(num_old_layers): copy kernel.value and bias.value from small.layers[i] to big.layers[i].
  5. Return big(x).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • num_old_layers, num_extra_layers, features: ints (passed as floats).

Output: 1-D array of shape (features,), the big model’s output after the warm-start copy.

