Param Surgery — Kernel Replace
Why this matters
Once you’ve trained a model, the params dict is just data: a nested dict of arrays. Sometimes you want to surgically swap one parameter without retraining, for example when porting weights from PyTorch, swapping in a quantized version of a single layer, running ablation studies where you replace a kernel with zeros or noise to see how the rest of the network responds, or debugging a layer in isolation by feeding it a known-good weight matrix.
Flax’s params are plain (frozen) dicts, so this is dramatically
easier than in PyTorch — no state_dict translation, no
load_state_dict strict-mode dance. Just walk the dict.
The shape of params
After params = model.init(rng, x):
{
    "params": {
        "Dense_0": {
            "kernel": jnp.ndarray,  # shape (in, out)
            "bias": jnp.ndarray,    # shape (out,)
        },
        "Dense_1": {...},
        ...
    }
}
The names Dense_0, Dense_1, … come from Flax’s auto-naming.
For a single nn.Dense (not inside an nn.compact __call__),
model.init(rng, x) produces just {"params": {"kernel": ..., "bias": ...}} —
no Dense_0 prefix, because the model is the Dense.
Three ways to surgically replace
1. Direct dict reconstruction (the explicit way)
params = model.init(rng, x)
old_kernel = params["params"]["kernel"]
new_kernel = jnp.full_like(old_kernel, 0.5)
new_params = {
    "params": {
        **params["params"],
        "kernel": new_kernel,
    }
}
out = model.apply(new_params, x)
Spreading **params["params"] keeps the bias (and any other params)
intact; only kernel is replaced. This is what we’ll do here.
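When the Dense layer sits under an auto-generated name, the same spread has to be repeated at each level of nesting. The sketch below uses a hand-built params dict as a stand-in for model.init output; the layer names and shapes are assumptions chosen for illustration.

```python
import jax.numpy as jnp

# Hand-built stand-in for model.init(...) output; names and shapes are assumed.
params = {
    "params": {
        "Dense_0": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros((8,))},
        "Dense_1": {"kernel": jnp.ones((8, 2)), "bias": jnp.zeros((2,))},
    }
}

old_kernel = params["params"]["Dense_0"]["kernel"]
new_kernel = jnp.full_like(old_kernel, 0.5)

# Rebuild only the dicts on the path to the kernel; everything else is reused.
new_params = {
    "params": {
        **params["params"],
        "Dense_0": {
            **params["params"]["Dense_0"],
            "kernel": new_kernel,
        },
    }
}
```

Only the dicts along the path are recreated. Dense_1 and the Dense_0 bias are the same objects as before, and the original params tree is left untouched.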
2. flax.core.unfreeze + mutate (older Flax)
from flax.core import freeze, unfreeze
p = unfreeze(params)
p["params"]["kernel"] = new_kernel
new_params = freeze(p)
Modern Flax (>= 0.7) returns plain dicts, so unfreeze is effectively
a no-op there (it just hands back a plain-dict copy), but the
pattern still appears in older examples.
3. jax.tree_util.tree_map_with_path (the surgical way)
def replace_kernel(path, leaf):
    if any("kernel" in str(p) for p in path):
        return jnp.full_like(leaf, 0.5)
    return leaf

new_params = jax.tree_util.tree_map_with_path(replace_kernel, params)
Useful when you want to replace every kernel in a deep model without enumerating layer names.
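The substring test on str(p) also matches any key that merely contains "kernel". A slightly stricter variant, sketched here on a hand-built plain dict, compares the last path entry's key exactly. The tree, its shapes, and the kernel_scale leaf are hypothetical and exist only to show the difference.

```python
import jax
import jax.numpy as jnp

# Hand-built stand-in for a params tree; "kernel_scale" is a hypothetical extra leaf.
params = {
    "params": {
        "Dense_0": {
            "kernel": jnp.ones((4, 8)),
            "kernel_scale": jnp.ones(()),
            "bias": jnp.zeros((8,)),
        },
        "Dense_1": {"kernel": jnp.ones((8, 2)), "bias": jnp.zeros((2,))},
    }
}


def replace_kernel(path, leaf):
    # For dict-based trees each path entry is a DictKey with a .key attribute;
    # getattr keeps this safe for other entry types (e.g. sequence indices).
    if getattr(path[-1], "key", None) == "kernel":
        return jnp.full_like(leaf, 0.5)
    return leaf


new_params = jax.tree_util.tree_map_with_path(replace_kernel, params)
```

Matching on the last entry only also avoids replacing leaves that happen to live under a parent whose name contains "kernel".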
Why jnp.full_like(kernel, value) and not just value?
model.apply expects each param to have the same shape as the
initialized version — feeding a scalar where the kernel expects
(in, out) raises a shape mismatch at trace time. jnp.full_like
preserves the shape and dtype but fills with a constant.
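The dtype point is easiest to see with a kernel that is not in the default float type. In this sketch the bfloat16 kernel and its shape are assumptions chosen to make the difference visible.

```python
import jax.numpy as jnp

kernel = jnp.ones((4, 2), dtype=jnp.bfloat16)  # assumed low-precision kernel

from_shape = jnp.full(kernel.shape, 0.5)  # shape matches, dtype is JAX's default float
from_like = jnp.full_like(kernel, 0.5)    # shape and dtype both follow the kernel

print(from_shape.dtype, from_like.dtype)
```

jnp.full also accepts an explicit dtype argument, so jnp.full(kernel.shape, 0.5, dtype=kernel.dtype) is an equivalent but more verbose spelling.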
What the function computes
With every kernel entry equal to replacement_value and the bias
untouched, the layer output is:
    y[n, j] = (sum over i: x[n, i] * replacement_value) + bias[j]
            = replacement_value * x[n, :].sum() + bias[j]
So every output in row n is the same scalar,
replacement_value * x[n, :].sum(), plus bias[j]. With the default
zero bias, all columns of a row are identical, which is a useful
sanity check during development.
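The identity can be checked numerically without Flax by writing the dense layer as x @ kernel + bias, which is the computation the formula describes. The input values, shapes, and the nonzero bias below are arbitrary choices for illustration.

```python
import jax.numpy as jnp

replacement_value = 0.5
x = jnp.arange(6.0).reshape(2, 3)  # assumed input, N = 2, in_dim = 3
kernel = jnp.full((3, 4), replacement_value)
bias = jnp.arange(4.0)             # nonzero on purpose, to show the bias[j] term

# Dense layer written out by hand.
y = x @ kernel + bias

# Closed form: one scalar per row, broadcast across columns, plus the bias.
expected = replacement_value * x.sum(axis=1, keepdims=True) + bias

print(jnp.allclose(y, expected))
```

keepdims=True keeps the row sums as an (N, 1) column so they broadcast against the (features,) bias.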
Common pitfalls
- Mutating the params dict: Flax params are nested frozen dicts (in modern Flax they’re plain dicts, but treat them as immutable). Construct a new dict via spread/replace rather than params["params"]["kernel"] = ....
- Shape mismatch: jnp.full(kernel.shape, value) works too, but jnp.full_like(kernel, value) also gets the dtype right.
- Replacing more than you meant: a typo like params["params"] = new_kernel replaces the entire params subtree with a single array. init would rebuild the tree, but apply then receives a tree without the kernel and bias entries it expects.
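One way to catch the last pitfall early is to compare the modified tree against the original before calling apply. The helper below is a sketch; same_layout is a hypothetical name rather than a Flax or JAX function, and the params tree is hand-built.

```python
import jax
import jax.numpy as jnp


def same_layout(a, b):
    """Return True if two param trees share structure, leaf shapes and dtypes."""
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        return False
    leaves_a = jax.tree_util.tree_leaves(a)
    leaves_b = jax.tree_util.tree_leaves(b)
    return all(
        la.shape == lb.shape and la.dtype == lb.dtype
        for la, lb in zip(leaves_a, leaves_b)
    )


# Hand-built stand-in for model.init(...) output on a single Dense layer.
params = {"params": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}}
new_kernel = jnp.full_like(params["params"]["kernel"], 0.5)

good = {"params": {**params["params"], "kernel": new_kernel}}
bad = {"params": new_kernel}  # the typo: whole subtree replaced by one array

print(same_layout(params, good), same_layout(params, bad))
```

The check is cheap because it only inspects metadata, so it can run on every surgery step during development.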
Problem
Initialize nn.Dense(features) with a PRNG key built from seed and
the input x, then surgically replace its kernel with
jnp.full_like(kernel, replacement_value). The bias should remain at
its initialized value (zeros, which is nn.Dense's default bias init).
Apply the model to x with the modified params and return the
output.
For a (N, in_dim) input, the output is (N, features).
Inputs:
- seed: float (cast to int), the PRNG seed.
- x: 2-D (N, in_dim).
- features: float (cast to int), the Dense output dimension.
- replacement_value: float, the value to fill the kernel with.
Output: 2-D (N, features) array.