NNX Surgery Zero Layer
Why this matters
Ablation by zeroing is a standard interpretability move: take a trained model, zero out one component, and see what the rest can still do. Did the layer matter? If it did, the output degrades or collapses; if it did not, you have evidence that the layer was redundant.
The naive way is to delete the layer from the architecture. But then you’ve changed shapes, broken downstream connections, and possibly invalidated a checkpoint. The neat trick is to keep the layer in place but zero its kernel: shapes are unchanged, the layer still exists, but its contribution is exactly zero.
Combined with nnx.Linear’s default zero-initialized bias, the result is x @ 0 + 0 = 0, so the downstream computation receives all zeros from this layer.
API recap: nnx.List for sub-module sequences
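Plain Python lists confuse nnx 0.12.6: it treats them as static and silently drops the parameters inside them. The fix is to wrap the sub-modules in nnx.List:
self.layers = nnx.List([
    # First layer: in_dim -> features; the rest: features -> features.
    nnx.Linear(in_dim if i == 0 else features, features, rngs=rngs)
    for i in range(num_layers)
])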
nnx.List is a data container: nnx.split and nnx.merge see its contents, and standard Python indexing and iteration work on it.
Surgery on a list element
Once the layers live in an nnx.List, getting at the last layer’s kernel is a plain Python expression:
last = model.layers[num_layers - 1]
last.kernel.value = jnp.zeros_like(last.kernel.value)
Notice that we leave the bias untouched: last.bias.value still equals its initial value (zeros, the nnx.Linear default). After the surgery:
output = last(prev_layer_out)
       = prev_layer_out @ 0 + bias
       = 0 + 0
       = 0
The model still runs, still has the same shapes, still has all its parameters in the state tree — but produces all-zeros output.
Worked example
model = MLP(in_features=3, num_layers=2, features=4, rngs=nnx.Rngs(0))
# Layer 0 transforms (3,) -> (4,) with random kernel + zero bias.
# Layer 1 transforms (4,) -> (4,) with random kernel + zero bias.
# Zero layer 1's kernel:
model.layers[1].kernel.value = jnp.zeros_like(model.layers[1].kernel.value)
out = model(jnp.array([1.0, 2.0, 3.0]))
# Layer 0: random output (4,)
# Layer 1: random_out @ zeros + zeros = (4,) of zeros
print(out)  # [0. 0. 0. 0.]
Why this generalizes
The same one-liner works on any layer: just change the index.
model.layers[3].kernel.value = ... zeros the fourth layer’s kernel (given at least four layers).
model.layers[0].bias.value = ... zeros the first layer’s bias.
For layer-wise ablation experiments (zero each layer one at a time and measure the output drop), this is the entire surgery code: no unfreeze() calls, no nested-dict edits, no path strings.
Common pitfalls
- Plain Python list, no nnx.List. The state tree comes back empty for the layer params: nnx.split(model) will silently omit them, and optimizer.update won’t update them. Always wrap a list of submodules in nnx.List.
- Indexing past the end. model.layers[num_layers] is one too far; num_layers is the count, not the last index. Use num_layers - 1.
- Zeroing the bias too. The question is “what does the rest of the net do without this layer’s contribution?” The bias is the layer’s contribution at zero input and is normally already zero. Don’t touch it; the test only asks about the kernel.
- Cast to int. num_layers and features arrive as floats; cast both with int() before using them in range() or as integer dimensions.
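The last two pitfalls can be seen in plain Python, with no Flax involved. The helper `layer_indices` below is hypothetical; it casts a float-encoded layer count and returns both the valid indices and the last one:

```python
def layer_indices(num_layers):
    """Hypothetical helper: valid indices and last index for a float-encoded count."""
    n = int(num_layers)  # 2.0 -> 2; range() rejects floats
    return list(range(n)), n - 1


try:
    range(2.0)  # what happens without the cast
except TypeError as err:
    print("range(2.0) fails:", err)

indices, last = layer_indices(2.0)
print(indices, last)  # [0, 1] 1
```

Note that the last valid index is `n - 1`, never `n` itself.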
Problem
Write zero_last_layer_kernel(seed, x, num_layers, features):
- Build an MLP(nnx.Module) with num_layers nnx.Linear layers held in an nnx.List. The first layer maps x.shape[-1] -> features; subsequent layers map features -> features.
- Zero the LAST layer’s kernel: last = model.layers[num_layers - 1]; last.kernel.value = jnp.zeros_like(last.kernel.value).
- Apply the model and return the output (length features).
With the bias defaulting to zeros, the output is all zeros.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- num_layers: int (passed as float).
- features: int (passed as float).
Output: 1-D (features,) array of zeros.