medium primitives

Param Surgery — Kernel Replace

Why this matters

Once you’ve trained a model, the params dict is just data: a nested dict of arrays. Sometimes you want to surgically swap one parameter without retraining, for example when porting weights from PyTorch, swapping in a quantized version of a single layer, running ablation studies where you replace a kernel with zeros or noise to see how the rest of the network responds, or debugging a layer in isolation by feeding it a known-good weight matrix.

Flax’s params are plain (frozen) dicts, so this is dramatically easier than in PyTorch — no state_dict translation, no load_state_dict strict-mode dance. Just walk the dict.

The shape of params

After params = model.init(rng, x):

{
  "params": {
    "Dense_0": {
      "kernel": jnp.ndarray,    # shape (in, out)
      "bias":   jnp.ndarray,    # shape (out,)
    },
    "Dense_1": {...},
    ...
  }
}

The names Dense_0, Dense_1, … come from Flax’s auto-naming of submodules. For a single nn.Dense (not inside an nn.compact __call__), model.init(rng, x) produces just {"params": {"kernel": ..., "bias": ...}}, with no Dense_0 prefix, because the model is the Dense.

Three ways to surgically replace

1. Direct dict reconstruction (the explicit way)

params = model.init(rng, x)
old_kernel = params["params"]["kernel"]
new_kernel = jnp.full_like(old_kernel, 0.5)
new_params = {
    "params": {
        **params["params"],
        "kernel": new_kernel,
    }
}
out = model.apply(new_params, x)

Spreading **params["params"] keeps the bias (and any other params) intact; only kernel is replaced. This is what we’ll do here.

2. flax.core.unfreeze + mutate (older Flax)

from flax.core import freeze, unfreeze
p = unfreeze(params)
p["params"]["kernel"] = new_kernel
new_params = freeze(p)

Modern Flax (>= 0.7) returns plain dicts, so unfreeze is a no-op in most code paths — but the pattern still appears in older examples.

3. jax.tree_util.tree_map_with_path (the surgical way)

def replace_kernel(path, leaf):
    if any("kernel" in str(p) for p in path):
        return jnp.full_like(leaf, 0.5)
    return leaf

new_params = jax.tree_util.tree_map_with_path(replace_kernel, params)

Useful when you want to replace every kernel in a deep model without enumerating layer names.
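A slightly stricter variant of the path-based approach is sketched below. It compares the last path entry's key exactly instead of substring-matching, so a layer that merely has "kernel" in its name is not caught by accident. The params tree here is hand-built as a stand-in for a two-layer model, and the helper name replace_kernels is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hand-built stand-in for a two-layer params tree (not produced by model.init).
params = {
    "params": {
        "Dense_0": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
        "Dense_1": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))},
    }
}


def replace_kernels(path, leaf, value=0.5):
    # For dict-based trees the last path entry is a DictKey whose .key
    # attribute holds the dict key; compare it exactly.
    if getattr(path[-1], "key", None) == "kernel":
        return jnp.full_like(leaf, value)
    return leaf


new_params = jax.tree_util.tree_map_with_path(replace_kernels, params)
```

tree_map_with_path returns a new tree, so the original params stays untouched and can be kept around for comparison.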

Why jnp.full_like(kernel, value) and not just value?

model.apply expects each param to have the same shape as the initialized version — feeding a scalar where the kernel expects (in, out) raises a shape mismatch at trace time. jnp.full_like preserves the shape and dtype but fills with a constant.

What the function computes

With every kernel entry equal to replacement_value and the bias untouched, the layer output is:

y[n, j] = (sum over i: x[n, i] * replacement_value) + bias[j]
        = replacement_value * x[n, :].sum() + bias[j]

So every entry in row n equals the same scalar, replacement_value * x[n, :].sum(), plus its own bias[j]. With the default zero bias, each output row is constant across its columns, which makes a useful sanity check during development.
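The identity is easy to confirm numerically without Flax at all. The sketch below writes the Dense computation out by hand as x @ kernel + bias; the shapes and the value 0.5 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (5, 3))  # (N, in_dim)
replacement_value = 0.5
kernel = jnp.full((3, 4), replacement_value)           # (in_dim, features)
bias = jnp.zeros((4,))                                 # zero bias, as in Dense's default

# Dense layer written out by hand.
dense_out = x @ kernel + bias

# Row sums scaled by the constant, broadcast across the feature axis.
shortcut = replacement_value * x.sum(axis=1, keepdims=True) + bias
```

Comparing dense_out and shortcut with jnp.allclose is a cheap check to keep next to any kernel-replacement code.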

Common pitfalls

  • Mutating the params dict: Flax params were historically nested FrozenDicts. Modern Flax returns plain dicts, but it is still safest to treat them as immutable, since other code may hold a reference to the same tree. Construct a new dict via spread/replace rather than assigning params["params"]["kernel"] = ... in place.
  • Shape mismatch: jnp.full(kernel.shape, value) works too, but jnp.full_like(kernel, value) also gets the dtype right.
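  • Replacing more than you meant: a typo like params["params"] = new_kernel replaces the entire params subtree with a single array, discarding the bias along with every other parameter. apply can then no longer find the kernel and bias entries it expects, so check the tree structure after any surgery.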
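One way to guard against these pitfalls is to compare the tree before and after surgery. The helper below is a sketch, not part of Flax or JAX: same_layout is a hypothetical name, and the params tree is hand-built to mimic a single Dense layer.

```python
import jax
import jax.numpy as jnp


def same_layout(old, new):
    """True if both trees share structure, leaf shapes, and leaf dtypes."""
    if jax.tree_util.tree_structure(old) != jax.tree_util.tree_structure(new):
        return False
    old_leaves = jax.tree_util.tree_leaves(old)
    new_leaves = jax.tree_util.tree_leaves(new)
    return all(
        jnp.shape(a) == jnp.shape(b) and jnp.result_type(a) == jnp.result_type(b)
        for a, b in zip(old_leaves, new_leaves)
    )


# Hand-built stand-in for the params of a single Dense layer.
params = {"params": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))}}
kernel = params["params"]["kernel"]

good = {"params": {**params["params"], "kernel": jnp.full_like(kernel, 0.5)}}
typo = {"params": jnp.full((3, 4), 0.5)}  # whole subtree replaced by one array
wrong_shape = {"params": {**params["params"], "kernel": jnp.full((4, 3), 0.5)}}

print(same_layout(params, good), same_layout(params, typo), same_layout(params, wrong_shape))
```

Only the first replacement keeps the layout intact; the other two are caught before they ever reach model.apply.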

Problem

Init nn.Dense(features) on (seed, x), then surgically replace its kernel with jnp.full_like(kernel, replacement_value). The bias should remain at its initialized value (zeros, which is nn.Dense’s default bias init).

Apply the model to x with the modified params and return the output.

For a (N, in_dim) input, the output is (N, features).

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (N, in_dim).
  • features: float (cast to int) — Dense output dimension.
  • replacement_value: float — value to fill the kernel with.

Output: 2-D (N, features) array.

