NNX Surgery Replace

Why this matters

“Model surgery” is the umbrella term for any post-init manipulation of a model’s parameters: replacing one layer’s weights with values from another model, zeroing a kernel for ablation, swapping a head for transfer learning, patching in pre-trained embeddings.

In Linen this is genuinely awkward. Params live in a frozen nested dict that you spread-update with the unfreeze/copy_dict/freeze dance; see Linen pos 93 (flax-surgery-replace). The cognitive overhead is real: you must remember the dict path for the weight you want, and because the frozen dict cannot be mutated in place, every change means rebuilding the tree around it.

In nnx the same surgery is a single attribute assignment:

model.kernel.value = new_value

That’s it. The nnx.Param wrapper exposes a writable .value attribute. There’s no nested-dict spread, no unfreeze, no path string. Just plain Python attribute assignment — exactly the way you’d modify any object.

API: nnx.Param.value

Every nnx.Param is a thin wrapper around a JAX array. It supports:

p = nnx.Param(jnp.zeros((3, 4)))
p.value                  # the underlying (3, 4) array
p.value = jnp.ones((3, 4))   # WRITE: replace the array in place

Arithmetic on the wrapper unwraps automatically (p + 1 works), but writing .value explicitly is the unambiguous way to get the raw array, for example when you want .shape or .T, and it is how you assign a new array.

For an nnx.Linear, the wrappers live at:

model.kernel.value      # (in_features, out_features)
model.bias.value        # (out_features,)

Both are mutable — you can freely overwrite either.

Worked example

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
print(model.kernel.value.shape)        # (3, 4)

# Replace every kernel value with 0.5:
model.kernel.value = jnp.full_like(model.kernel.value, 0.5)

x = jnp.array([1.0, 2.0, 3.0])
# Each entry is x @ kernel + bias = (1 + 2 + 3) * 0.5 + 0 = 3.0 (bias starts at zero).
print(model(x))

jnp.full_like(arr, v) returns a new array with the same shape and dtype as arr, filled with v. It’s the right idiom for “replacement that keeps shape and dtype consistent” — never an accidental dtype promotion.

Compare with Linen

The Linen equivalent (pos 93):

new_params = {
    "params": {
        "kernel": jnp.full_like(params["params"]["kernel"], v),
        "bias": params["params"]["bias"],     # carry over
    }
}
out = model.apply(new_params, x)

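You must explicitly carry over every other param, spell out the dict path params/kernel, and round-trip through model.apply. Each is a place where a typo can bite: a misspelled or forgotten key usually fails only once apply runs, and a value written under the wrong (but valid) key produces wrong results with no error at all.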

The nnx version is one assignment. That’s the whole point of the rewrite: surgery becomes a Python-native operation rather than a pytree-manipulation puzzle.

Common pitfalls

  • Forgetting .value. model.kernel = jnp.full(...) REPLACES the wrapper itself with a raw array, breaking the param tracking. You want model.kernel.value = ....
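  • Wrong shape. jnp.full_like(model.kernel.value, v) keeps shape and dtype. Writing jnp.full((4, 3), v) with the dimensions flipped produces a shape-incompatible kernel, and the next forward call fails with a shape error.
  • Dtype mismatch. Mixing dtypes triggers JAX type promotion. float32 mixed with float64 promotes to float64 only when jax_enable_x64 is on (by default JAX stays in float32), and bfloat16 mixed with float32 promotes to float32. Either way, a mismatched replacement can change precision and cost performance. full_like preserves dtype.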
  • Treating the assignment as functional. It mutates the model in place. After the assignment, model is the patched object; there’s no “old model” to compare against unless you saved a copy first.

Problem

Write replace_kernel(seed, x, features, replacement_value):

  1. Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
  2. Replace the kernel: model.kernel.value = jnp.full_like(model.kernel.value, float(replacement_value)). Bias stays at its init (zeros).
  3. Apply: return model(x).

Output has shape (features,). Each entry equals sum(x) * replacement_value + bias; since bias is initialized to zeros, each entry is simply x.sum() * replacement_value.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).
  • replacement_value: float.

Output: 1-D (features,).
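The expected output stated above can be checked without NNX at all. The helper below is a hypothetical reference (expected_output is not part of the problem or of any library) that spells out the arithmetic with plain jax.numpy; it is meant for comparing against your own replace_kernel, not as the solution itself.

```python
import jax.numpy as jnp

def expected_output(x, features, replacement_value):
    """Reference arithmetic: constant kernel, zero bias (hypothetical helper)."""
    n_out = int(features)  # features arrives as a float
    kernel = jnp.full((x.shape[-1], n_out), float(replacement_value), dtype=x.dtype)
    bias = jnp.zeros((n_out,), dtype=x.dtype)
    return x @ kernel + bias

x = jnp.array([1.0, 2.0, 3.0])
ref = expected_output(x, 4.0, 0.5)
```

With this input, ref has shape (4,) and every entry equals x.sum() * 0.5 = 3.0.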
