Param Surgery — Zero Last Layer

Why this matters

Sometimes you want to neutralize a layer without removing it from the architecture. Reasons:

  • Stochastic depth / layer-dropping at training time: zero out a whole residual block on this batch, keep its slot so other branches still see the same shape.
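  • Initialization tricks: Fixup, ReZero, and related identity-initialization schemes start each residual branch at zero (Fixup zero-initializes the last layer of each branch; ReZero scales the branch by a learnable scalar initialized to zero), so the residual stream sees an identity at step 0.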
  • Ablation studies: “what if this MLP block contributed nothing?” Just zero its kernel and re-evaluate; the architecture is preserved so all the other shapes/keys still line up.
  • Easier than re-architecting: removing a layer changes the param tree structure, which breaks checkpoint loading. Zeroing a kernel keeps the structure identical.

The clean way: zero the kernel, keep the bias

An nn.Dense(features) layer computes y = x @ kernel + bias. If kernel is all zeros:

y = x @ 0 + bias = bias

The output becomes the bias, regardless of the input. If the bias was also zero (Flax’s default), the layer outputs zero — which is fine in a residual setup (x + zero_block(x) = x, so the layer is truly “skipped”). If the bias is nonzero, the layer outputs a constant.
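The identity above is easy to check by hand. The following is a minimal sketch in plain jax.numpy (no Flax module involved); the shapes and bias values are arbitrary choices for illustration, not anything the problem prescribes.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)        # batch of 2, in_dim = 3
kernel = jnp.zeros((3, 4))               # zeroed kernel, features = 4
bias = jnp.array([0.5, -1.0, 2.0, 0.0])  # arbitrary nonzero bias (illustrative)

# Dense forward pass: with a zero kernel every row of y is just the bias.
y = x @ kernel + bias

# With a zero bias as well, the layer outputs zeros ...
y_zero_bias = x @ kernel + jnp.zeros(4)

# ... so in a residual setup h + block(h) collapses to h.
h = jnp.ones((2, 4))
out = h + (h @ jnp.zeros((4, 4)) + jnp.zeros(4))
```

Here y has the bias repeated on every row, and out is identical to h.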

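For Fixup/ReZero-style zero initialization, you specifically want the kernel zeroed but the bias kept. The layer can still learn: the gradient with respect to the kernel depends on the layer’s input and the upstream gradient, not on the kernel’s current value, so a zeroed kernel still receives a nonzero update, and so does the bias.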

The params layout for an N-layer MLP

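import flax.linen as nn

class MLP(nn.Module):
    num_layers: int   # how many Dense layers to stack
    features: int     # output width of every Dense layer

    @nn.compact
    def __call__(self, x):
        # Each unnamed nn.Dense gets an auto-generated name: Dense_0, Dense_1, ...
        for _ in range(self.num_layers):
            x = nn.Dense(self.features)(x)
        return x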

After params = model.init(rng, x), you get:

{
  "params": {
    "Dense_0": {"kernel": ..., "bias": ...},
    "Dense_1": {"kernel": ..., "bias": ...},
    ...
    "Dense_{N-1}": {"kernel": ..., "bias": ...},
  }
}

The auto-naming gives Dense_0, Dense_1, … Dense_{N-1}. The last layer is Dense_{num_layers - 1}.

The surgery

Same idea as the previous “kernel replace” problem, but with two twists:

  1. We’re targeting a named sublayer nested one level deeper: params["params"]["Dense_{N-1}"]["kernel"].
  2. We zero (with jnp.zeros_like) instead of using a constant.

last_name = f"Dense_{num_layers - 1}"
last_layer = params["params"][last_name]
new_kernel = jnp.zeros_like(last_layer["kernel"])
new_params = {
    "params": {
        **params["params"],
        last_name: {
            **last_layer,         # keeps bias
            "kernel": new_kernel, # overrides only the kernel
        },
    }
}

Two levels of dict spread: the outer keeps siblings (Dense_0, …, Dense_{N-2}) intact; the inner keeps bias intact within the targeted layer.
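The two-level spread can be wrapped in a small function and checked without Flax at all. In the sketch below, zero_last_kernel is a hypothetical helper name, and the params tree is a hand-built stand-in that only mimics the layout shown earlier.

```python
import jax.numpy as jnp

def zero_last_kernel(params, num_layers):
    """Hypothetical helper: return a new tree with the last Dense kernel zeroed."""
    last_name = f"Dense_{num_layers - 1}"
    last_layer = params["params"][last_name]
    return {
        "params": {
            **params["params"],            # outer spread: keep sibling layers
            last_name: {
                **last_layer,              # inner spread: keep the bias
                "kernel": jnp.zeros_like(last_layer["kernel"]),
            },
        }
    }

# Stand-in tree with the same layout as a 2-layer MLP (values are illustrative).
params = {"params": {
    "Dense_0": {"kernel": jnp.full((3, 4), 2.0), "bias": jnp.zeros(4)},
    "Dense_1": {"kernel": jnp.full((4, 4), 3.0), "bias": jnp.ones(4)},
}}
new_params = zero_last_kernel(params, num_layers=2)
```

Because the spreads build new dicts rather than mutating the old ones, params is left untouched: the original Dense_1 kernel is still all threes, and the untouched entries in new_params are the very same array objects as in params.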

Test layer’s bias is initialized to ones

To make these tests visibly distinguish “kernel zeroed” from “kernel untouched,” the last Dense in the MLP uses bias_init=nn.initializers.ones (all the earlier Dense layers use the default zero bias). With the last kernel zeroed, the model output is exactly the all-ones vector of the last layer’s bias — a clean signal.

Build the layers like this:

for i in range(num_layers):
    is_last = i == num_layers - 1
    bias_init = nn.initializers.ones if is_last else nn.initializers.zeros
    x = nn.Dense(features, bias_init=bias_init)(x)

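In real Fixup/ReZero code there’s no such cosmetic: the bias init matches whatever your scheme prescribes. We’re forcing ones here only so the expected output isn’t all-zero, which would fail to detect a bug where the bias got zeroed or dropped along with the kernel. Note that the output alone cannot catch a bug where the wrong layer got zeroed: this MLP has no activations and the earlier biases are zero, so zeroing any one kernel still yields all ones. To check which layer was touched, compare the param trees directly.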

Why we don’t just delete the layer

Two reasons:

  1. Checkpoint compat. Removing Dense_{N-1} from the architecture changes the param tree, breaking older checkpoints’ restore. Zeroing the kernel keeps everything identical on disk.
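  2. Gradient flow. A removed layer doesn’t get gradients. A zero-kernel layer still has trainable parameters (kernel and bias): both receive gradients, so they can wake back up during fine-tuning. While the kernel is exactly zero, though, nothing propagates through it to the layers upstream; they start receiving gradient through this layer only once the kernel has moved off zero.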
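The gradient-flow point can be checked with jax.grad on a hand-written two-layer network. This is a sketch under stated assumptions: two_layer and loss are illustrative names, the layers are plain matrix multiplies rather than Flax modules, and the squared-sum loss is chosen only to give a nonzero upstream gradient.

```python
import jax
import jax.numpy as jnp

def two_layer(params, x):
    # Two stacked Dense layers, no activation, mirroring the MLP in this section.
    h = x @ params["Dense_0"]["kernel"] + params["Dense_0"]["bias"]
    return h @ params["Dense_1"]["kernel"] + params["Dense_1"]["bias"]

def loss(params, x):
    # Illustrative loss: sum of squared outputs.
    return jnp.sum(two_layer(params, x) ** 2)

params = {
    "Dense_0": {"kernel": jnp.full((3, 4), 0.5), "bias": jnp.zeros(4)},
    "Dense_1": {"kernel": jnp.zeros((4, 2)), "bias": jnp.ones(2)},  # zeroed kernel
}
x = jnp.ones((5, 3))

grads = jax.grad(loss)(params, x)

# The zeroed kernel and its bias both receive nonzero gradients.
last_kernel_grad = grads["Dense_1"]["kernel"]
last_bias_grad = grads["Dense_1"]["bias"]

# Everything upstream of the zero kernel gets exactly zero gradient.
first_kernel_grad = grads["Dense_0"]["kernel"]
```

So the zeroed layer can learn from the very first step, while the layers before it have to wait until the last kernel is no longer zero.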

Common pitfalls

  • Off-by-one on layer name: it’s Dense_{N-1} for the last layer, not Dense_{N}.
  • Forgetting the inner spread: last_name: {"kernel": new_kernel} (no **last_layer) drops the bias entirely; apply then raises a missing-key error or, worse in older Flax, silently re-inits.
  • Zeroing more than the kernel: last_layer = {"kernel": 0, "bias": 0} wipes the bias too. The spec says keep the bias.
  • Using jnp.zeros(shape) instead of jnp.zeros_like: works only if you correctly read the kernel’s shape; zeros_like is safer and preserves dtype.
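Two of these pitfalls are easy to reproduce in isolation. The sketch below uses a hand-built layer dict with a bfloat16 kernel (an assumption for illustration; your kernels may well be float32) to show the dtype mismatch and the dropped bias.

```python
import jax.numpy as jnp

kernel = jnp.ones((4, 3), dtype=jnp.bfloat16)

# jnp.zeros(shape) falls back to the default float dtype, not the kernel's dtype.
by_shape = jnp.zeros(kernel.shape)

# jnp.zeros_like copies both shape and dtype from the kernel.
by_like = jnp.zeros_like(kernel)

layer = {"kernel": kernel, "bias": jnp.ones(3, dtype=jnp.bfloat16)}

# Without the inner spread the bias key disappears from the layer dict.
without_spread = {"kernel": by_like}

# With the inner spread the bias survives and only the kernel is replaced.
with_spread = {**layer, "kernel": by_like}
```

Passing dtype=kernel.dtype to jnp.zeros would also fix the mismatch, but zeros_like leaves nothing to forget.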

Problem

Build an MLP with num_layers layers, each an nn.Dense(features). The last layer uses bias_init=nn.initializers.ones; all earlier layers use the default zero bias. Initialize the model with a PRNG key built from seed and the input x, then zero the last layer’s kernel while keeping its bias. Apply the model with the modified params and return the output.

The result is an (N, features) array of all ones, where N is the batch size of x, because the last layer’s bias is ones and its kernel is zero.

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (N, in_dim).
  • num_layers: float (cast to int) — number of Dense layers.
  • features: float (cast to int) — output features per Dense.

Output: 2-D (N, features) array (all ones).
