Param Surgery — Zero Last Layer
Why this matters
Sometimes you want to neutralize a layer without removing it from the architecture. Reasons:
- Stochastic depth / layer-dropping at training time: zero out a whole residual block on this batch, keep its slot so other branches still see the same shape.
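- Initialization tricks: Fixup, ReZero, and earlier identity-init schemes all arrange for each residual branch to output zero at step 0, so the residual stream starts as the identity. Fixup zero-initializes the last layer of each branch; ReZero multiplies each branch by a learnable scalar initialized to zero.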
- Ablation studies: “what if this MLP block contributed nothing?” Just zero its kernel and re-evaluate; the architecture is preserved so all the other shapes/keys still line up.
- Easier than re-architecting: removing a layer changes the param tree structure, which breaks checkpoint loading. Zeroing a kernel keeps the structure identical.
The clean way: zero the kernel, keep the bias
A nn.Dense(features) computes y = x @ kernel + bias. If kernel is all zeros:
y = x @ 0 + bias = bias
The output is the bias, broadcast across the batch, regardless of the input. If the bias is also zero (Flax’s default), the layer outputs zero. That is fine in a residual setup: x + zero_block(x) = x, so the layer is truly “skipped”. If the bias is nonzero, the layer outputs a constant.
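A quick numerical check of the arithmetic above. It uses plain jax.numpy rather than Flax so the shapes stay explicit. The helper dense is illustrative: it applies the same affine map as nn.Dense but is not the Flax implementation, and all sizes are arbitrary.

```python
import jax.numpy as jnp

def dense(x, kernel, bias):
    # Illustrative stand-in for nn.Dense: y = x @ kernel + bias
    return x @ kernel + bias

x = jnp.arange(6.0).reshape(2, 3)      # batch of 2 inputs, in_dim = 3
kernel = jnp.zeros((3, 4))             # zeroed kernel, shape (in_dim, features)

# Nonzero bias: every row of the output equals the bias, whatever x is.
bias = jnp.array([1.0, 2.0, 3.0, 4.0])
y_const = dense(x, kernel, bias)       # shape (2, 4)

# Zero kernel and zero bias inside a residual block (in_dim == features):
# x + block(x) == x, so the block is skipped.
h = jnp.arange(8.0).reshape(2, 4)
block_out = dense(h, jnp.zeros((4, 4)), jnp.zeros(4))
skipped = h + block_out                # identical to h
```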
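For Fixup-style zero-init, you specifically want the kernel zeroed but the bias kept. The zeroed kernel is not frozen. Both the bias and the kernel still receive gradients (the kernel’s gradient is x^T @ dL/dy, which is nonzero for nonzero inputs), so the layer can learn its way out of zero.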
The params layout for an N-layer MLP
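```python
import flax.linen as nn

class MLP(nn.Module):
    num_layers: int
    features: int

    @nn.compact
    def __call__(self, x):
        # Each nn.Dense created here is auto-named Dense_0, Dense_1, ...
        for i in range(self.num_layers):
            x = nn.Dense(self.features)(x)
        return x
```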
After params = model.init(rng, x), you get:
```
{
    "params": {
        "Dense_0": {"kernel": ..., "bias": ...},
        "Dense_1": {"kernel": ..., "bias": ...},
        ...
        "Dense_{N-1}": {"kernel": ..., "bias": ...},
    }
}
```
The auto-naming gives Dense_0, Dense_1, … Dense_{N-1}. The
last layer is Dense_{num_layers - 1}.
The surgery
Same idea as the previous “kernel replace” problem, but with two twists:
- We’re targeting a named sublayer nested one level deeper: params["params"]["Dense_{N-1}"]["kernel"].
- We zero it (with jnp.zeros_like) instead of filling it with a constant.
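```python
import jax.numpy as jnp

# params = model.init(rng, x) and num_layers come from the setup above
last_name = f"Dense_{num_layers - 1}"
last_layer = params["params"][last_name]
new_kernel = jnp.zeros_like(last_layer["kernel"])  # same shape and dtype, all zeros

new_params = {
    "params": {
        **params["params"],        # keeps Dense_0 ... Dense_{N-2}
        last_name: {
            **last_layer,          # keeps bias
            "kernel": new_kernel,  # overrides only the kernel
        },
    }
}
```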
Two levels of dict spread: the outer keeps siblings (Dense_0, …,
Dense_{N-2}) intact; the inner keeps bias intact within the
targeted layer.
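The same two-level spread can be packaged as a reusable function. zero_kernel below is a hypothetical helper, not a Flax API. It is tested on a hand-built dict that mimics the layout model.init produces. Unlike the snippet above, it also spreads the top-level dict, so other variable collections (for example batch_stats, if a model had them) would survive.

```python
import jax.numpy as jnp

def zero_kernel(params, layer_name):
    """Hypothetical helper: copy of params with one layer's kernel zeroed.

    Only the dicts on the path to the kernel are rebuilt; every other
    array is shared with the input tree, which is never mutated.
    """
    layer = params["params"][layer_name]
    return {
        **params,                          # other top-level collections, if any
        "params": {
            **params["params"],            # sibling layers untouched
            layer_name: {
                **layer,                   # bias kept
                "kernel": jnp.zeros_like(layer["kernel"]),
            },
        },
    }

# Hand-built tree with the same layout model.init gives for num_layers = 2.
params = {
    "params": {
        "Dense_0": {"kernel": jnp.full((3, 4), 0.5), "bias": jnp.zeros(4)},
        "Dense_1": {"kernel": jnp.full((4, 4), 0.5), "bias": jnp.ones(4)},
    }
}
new_params = zero_kernel(params, "Dense_1")
```

The original params keeps its 0.5 kernel, because the function returns a new tree instead of editing in place. That matches how JAX code treats parameters as immutable values.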
Test layer’s bias is initialized to ones
To make these tests visibly distinguish “kernel zeroed” from “kernel
untouched,” the last Dense in the MLP uses bias_init=nn.initializers.ones
(all the earlier Dense layers use the default zero bias). With the last kernel zeroed, every row of the model output is exactly the last layer’s all-ones bias: a clean signal.
Build the layers like this:
```python
# Inside MLP.__call__: only the last Dense gets a ones bias.
for i in range(self.num_layers):
    is_last = i == self.num_layers - 1
    bias_init = nn.initializers.ones if is_last else nn.initializers.zeros
    x = nn.Dense(self.features, bias_init=bias_init)(x)
```
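In real Fixup/ReZero code there’s no such cosmetic choice: the bias init matches whatever your scheme prescribes. We force ones here only so the expected output is not all zeros. With zero biases everywhere, a bug that also wiped the bias would give the same all-zero output as the correct surgery. This test cannot tell which kernel was zeroed, though. The MLP is purely linear and its earlier biases are zero, so zeroing any earlier kernel also feeds zeros into the last layer and yields the same all-ones output.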
Why we don’t just delete the layer
Two reasons:
- Checkpoint compat. Removing Dense_{N-1} from the architecture changes the param tree, so older checkpoints no longer restore. Zeroing the kernel keeps the tree structure and every array’s shape and dtype identical, so existing checkpoints still load.
- Gradient flow. A removed layer gets no gradients. A zero-kernel layer still has trainable parameters (kernel and bias). Gradients flow into both, so they can wake back up during fine-tuning.
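To see the gradient-flow point numerically, here is a sketch with jax.grad on a single dense layer with a zero kernel. The loss function, inputs and target are made up for illustration. The example also shows a caveat. The gradient with respect to the layer’s input is zero, because it equals dL/dy @ kernel^T. So in a plain stack like the MLP here, layers before the zeroed one receive no gradient until the kernel moves off zero. In a residual block, the skip path carries gradient around it.

```python
import jax
import jax.numpy as jnp

def loss(kernel, bias, x, target):
    # Squared error of one dense layer y = x @ kernel + bias
    y = x @ kernel + bias
    return jnp.sum((y - target) ** 2)

x = jnp.array([[1.0, 2.0, 3.0]])     # (1, in_dim)
kernel = jnp.zeros((3, 2))           # zeroed kernel
bias = jnp.ones(2)
target = jnp.array([[3.0, -1.0]])

# Gradients w.r.t. kernel (arg 0), bias (arg 1) and the input x (arg 2)
g_kernel, g_bias, g_x = jax.grad(loss, argnums=(0, 1, 2))(kernel, bias, x, target)
# g_kernel = x^T @ dL/dy is nonzero; g_x = dL/dy @ kernel^T is all zeros
```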
Common pitfalls
- Off-by-one on layer name: the last layer is Dense_{N-1}, not Dense_{N}. Looking up Dense_{N} raises a KeyError.
- Forgetting the inner spread: last_name: {"kernel": new_kernel} (no **last_layer) drops the bias entirely. apply then raises a missing-parameter error or, worse, in older Flax silently re-inits it.
- Zeroing more than the kernel: last_layer = {"kernel": 0, "bias": 0} wipes the bias too. The spec says keep the bias.
- Using jnp.zeros(shape) instead of jnp.zeros_like: jnp.zeros(shape) only works if you read the kernel’s shape correctly, and it defaults to float32, so a bfloat16 kernel silently changes dtype. zeros_like copies both shape and dtype.
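A cheap guard against two of these pitfalls is to check that the surgered tree has the same layout as the original. check_same_layout is a hypothetical helper built on jax.tree_util. It compares tree structure, then each leaf’s shape and dtype. It catches a dropped bias and a dtype change. It does not catch a wiped bias, because a zeroed bias keeps the same layout; a value check such as the all-ones output covers that case.

```python
import jax
import jax.numpy as jnp

def check_same_layout(old, new):
    """Hypothetical sanity check: same tree structure, shapes, and dtypes."""
    if jax.tree_util.tree_structure(old) != jax.tree_util.tree_structure(new):
        return False
    return all(
        a.shape == b.shape and a.dtype == b.dtype
        for a, b in zip(jax.tree_util.tree_leaves(old), jax.tree_util.tree_leaves(new))
    )

layer = {
    "kernel": jnp.ones((4, 4), dtype=jnp.bfloat16),
    "bias": jnp.ones(4, dtype=jnp.bfloat16),
}
params = {"params": {"Dense_1": layer}}

good = {"params": {"Dense_1": {**layer, "kernel": jnp.zeros_like(layer["kernel"])}}}
dropped_bias = {"params": {"Dense_1": {"kernel": jnp.zeros_like(layer["kernel"])}}}
wrong_dtype = {"params": {"Dense_1": {**layer, "kernel": jnp.zeros((4, 4))}}}  # float32
```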
Problem
Build an N-layer MLP where every layer is nn.Dense(features). The last layer uses bias_init=nn.initializers.ones; all earlier layers use the default zero bias. Initialize the model with a PRNG key made from seed and the input x, then ZERO the last layer’s kernel while keeping its bias. Apply the model with the modified params and return the output.
The result is an (N, features) array of all ones, because the last layer’s bias is ones and its kernel is zero.
Inputs:
- seed: float (cast to int), the PRNG seed.
- x: 2-D array of shape (N, in_dim).
- num_layers: float (cast to int), the number of Dense layers.
- features: float (cast to int), the output features per Dense.
Output: 2-D (N, features) array (all ones).