
NNX Surgery Zero Layer

Why this matters

Ablation by zeroing is a standard interpretability move: take a trained model, zero out one component, and see what the rest can still do. Did the layer matter? If it did, the output degrades or collapses. If it did not, you have evidence that the layer was redundant.

The naive way is to delete the layer from the architecture. But then you’ve changed shapes, broken downstream connections, and possibly invalidated a checkpoint. The neat trick is to keep the layer in place but zero its kernel: shapes are unchanged, the layer still exists, but its contribution is exactly zero.

Combined with nnx.Linear's default zero-initialized bias, the layer computes x @ 0 + 0 = 0, so the downstream computation receives all zeros from this layer.

API recap: nnx.List for sub-module sequences

A plain Python list of submodules confuses Flax NNX 0.12.6: it treats the list as static and silently drops the params. The fix is to wrap the list in nnx.List:

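self.layers = nnx.List([
    # The first layer maps in_dim -> features; every later layer maps features -> features.
    nnx.Linear(in_dim if i == 0 else features, features, rngs=rngs)
    for i in range(num_layers)
])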

nnx.List is a data container: nnx.split/nnx.merge see its contents, and standard Python indexing and iteration work on it.

Surgery on a list element

Once the layers live in an nnx.List, reaching the last layer's kernel is a plain Python expression:

last = model.layers[num_layers - 1]
last.kernel.value = jnp.zeros_like(last.kernel.value)

Notice that we leave the bias untouched: last.bias.value still equals its initial value (zeros, the default for nnx.Linear). After the surgery, with prev standing for the previous layer's output:

output = last(prev)
       = prev @ 0 + bias
       = 0 + 0
       = 0

The model still runs, still has the same shapes, still has all its parameters in the state tree — but produces all-zeros output.

Worked example

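import jax.numpy as jnp
from flax import nnx

# MLP: the nnx.List-based module described in the Problem section.
model = MLP(in_features=3, num_layers=2, features=4, rngs=nnx.Rngs(0))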

# Layer 0 transforms (3,) -> (4,) with random kernel + zero bias.
# Layer 1 transforms (4,) -> (4,) with random kernel + zero bias.

# Zero layer 1's kernel:
model.layers[1].kernel.value = jnp.zeros_like(model.layers[1].kernel.value)

out = model(jnp.array([1.0, 2.0, 3.0]))
# Layer 0: random (4,) output
# Layer 1: random_out @ zeros + zeros = (4,) of zeros
print(out)   # [0. 0. 0. 0.]

Why this generalizes

The same one-liner works on any layer: just replace the index. model.layers[3].kernel.value = ... zeros the fourth layer. model.layers[0].bias.value = ... zeros the first layer’s bias.

For layer-by-layer ablation experiments (zero each layer one at a time and measure how much the output changes), this is the entire surgery code: no unfreeze, no nested-dict edits, no path strings.

Common pitfalls

  • Plain Python list, no nnx.List. The state tree comes back empty for the layer params. nnx.split(model) will silently omit them, and optimizer.update won’t update them. Always wrap a list of submodules in nnx.List.
  • Indexing past the end. model.layers[num_layers] is one too far; it’s the count, not the last index. Use num_layers - 1.
  • Zeroing the bias too. The question is "what does the rest of the network do without this layer's contribution?" The bias is the layer's output at zero input, and it is normally already zero. Don't touch it; the test only asks about the kernel.
  • Forgetting to cast to int. seed, num_layers, and features all arrive as floats. Cast them with int() before using them in range(), as integer layer dimensions, or as the nnx.Rngs seed.

Problem

Write zero_last_layer_kernel(seed, x, num_layers, features):

  1. Build an MLP(nnx.Module) with num_layers nnx.Linear layers stored in an nnx.List. The first layer maps x.shape[-1] -> features; subsequent layers map features -> features.
  2. Zero the LAST layer’s kernel: last = model.layers[num_layers - 1]; last.kernel.value = jnp.zeros_like(last.kernel.value).
  3. Apply the model and return the output (length features).

With bias defaulting to zeros, the output is all-zeros.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • num_layers: int (passed as float).
  • features: int (passed as float).

Output: 1-D (features,) array of zeros.
