NNX Surgery Replace
Why this matters
“Model surgery” is the umbrella term for any post-init manipulation of a model’s parameters: replacing one layer’s weights with values from another model, zeroing a kernel for ablation, swapping a head for transfer learning, patching in pre-trained embeddings.
In Linen this is genuinely awkward. Params live in a frozen
nested dict that you spread-update with the unfreeze/copy_dict/
freeze dance (see Linen pos 93, flax-surgery-replace). The
cognitive overhead is real: you must remember the dict path for the
weight you want and rebuild everything around it, because the frozen
original can’t be mutated in place.
In nnx the same surgery is a single attribute assignment:
model.kernel.value = new_value
That’s it. The nnx.Param wrapper exposes a writable .value
attribute. There’s no nested-dict spread, no unfreeze, no path
string. Just plain Python attribute assignment — exactly the way
you’d modify any object.
API: nnx.Param.value
Every nnx.Param is a thin wrapper around a JAX array. It supports:
import jax.numpy as jnp
from flax import nnx
p = nnx.Param(jnp.zeros((3, 4)))
p.value # the underlying (3, 4) array
p.value = jnp.ones((3, 4)) # WRITE: the same wrapper now holds a new (3, 4) array
Reads from the wrapper auto-unwrap for arithmetic
(p + 1 works), but explicit .value is the reliable way to get the raw
array (for .shape, .T, or passing it to a jnp function), and it is
required when you assign a new array.
For an nnx.Linear, the wrappers live at:
model.kernel.value # (in_features, out_features)
model.bias.value # (out_features,)
Both are writable: you can overwrite either, as long as the replacement keeps the shape the layer expects.
Worked example
import jax.numpy as jnp
from flax import nnx
model = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
print(model.kernel.value.shape) # (3, 4)
# Replace every kernel value with 0.5:
model.kernel.value = jnp.full_like(model.kernel.value, 0.5)
x = jnp.array([1.0, 2.0, 3.0])
print(model(x)) # x @ kernel + bias = 6 * 0.5 + 0 = 3.0 per entry (bias inits to zeros)
jnp.full_like(arr, v) returns a new array with the same shape and
dtype as arr, filled with v (cast to arr’s dtype). That makes it the
right idiom for a replacement that keeps shape and dtype consistent,
with no accidental dtype promotion.
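To see the dtype point concretely, compare full_like against a plain jnp.full on a bfloat16 kernel. The bfloat16 dtype is only an illustrative assumption here; nnx.Linear creates float32 parameters by default.

```python
import jax.numpy as jnp

kernel = jnp.zeros((3, 4), dtype=jnp.bfloat16)  # stand-in for a bf16 kernel

patched = jnp.full_like(kernel, 0.5)   # shape AND dtype follow `kernel`
naive = jnp.full(kernel.shape, 0.5)    # dtype comes from the fill value instead

print(patched.dtype, patched.shape)    # bfloat16 (3, 4)
print(naive.dtype)                     # float32 with default JAX settings
```

Assigning naive to a bf16 parameter would quietly change that parameter’s dtype; full_like avoids the question entirely.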
Compare with Linen
The Linen equivalent (pos 93):
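import jax
import jax.numpy as jnp
import flax.linen as nn
model = nn.Dense(features=4)
x = jnp.array([1.0, 2.0, 3.0])
params = model.init(jax.random.key(0), x)
v = 0.5
new_params = {
    "params": {
        "kernel": jnp.full_like(params["params"]["kernel"], v),
        "bias": params["params"]["bias"],  # carry over
    }
}
out = model.apply(new_params, x)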
You must explicitly carry over every other param, name the dict path
“params/kernel”, and round-trip through model.apply. That is three
places to get wrong: a mistyped key or a forgotten param makes the
call fail, and carrying over the wrong array gives silently wrong results.
The nnx version is one assignment. That’s the whole point of the rewrite: surgery becomes a Python-native operation rather than a pytree-manipulation puzzle.
Common pitfalls
- Forgetting .value. model.kernel = jnp.full(...) REPLACES the wrapper itself with a raw array, breaking the param tracking. You want model.kernel.value = ....
- Wrong shape. jnp.full_like(model.kernel.value, v) keeps shape and dtype. jnp.full((4, 3), v) with the dimensions flipped produces a shape-incompatible array, and the next forward call crashes.
- Dtype mismatch. Mixing float32 and float64 arrays silently promotes everything to float64 (possible only when jax_enable_x64 is on), with a perf hit. full_like preserves dtype.
- Treating the assignment as functional. It mutates the model in place. After the assignment, model is the patched object; there’s no “old model” to compare against unless you saved a copy first.
Problem
Write replace_kernel(seed, x, features, replacement_value):
- Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
- Replace the kernel: model.kernel.value = jnp.full_like(model.kernel.value, float(replacement_value)). Bias stays at its init (zeros).
- Apply: return model(x).
Output has shape (features,). Each entry equals sum(x) * replacement_value + bias, and bias is initialized to zeros, so each entry equals x.sum() * replacement_value.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
- replacement_value: float.
Output: 1-D array of shape (features,).