NNX Port: Linen MLP → NNX MLP
Why this matters
The single-Dense port (previous problem) is the building block. Real models have many layers, and you have to keep track of which Linen-side weight maps to which nnx-side attribute.
Linen and nnx use different naming schemes:
- Linen auto-names submodules declared inside @nn.compact as Dense_0, Dense_1, … in declaration order. Their params live under linen_params["params"]["Dense_0"], ["Dense_1"], …
- nnx uses the attribute names you assigned: self.fc1, self.fc2, …
Porting a multi-layer model comes down to a translation table from Dense_<i> to fc<j> (for this MLP: Dense_0 → fc1, Dense_1 → fc2). Build it once, follow it carefully, and you're done.
The two architectures
```python
import jax
import flax.linen as nn
from flax import nnx


class LinenMLP(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden)(x)  # Dense_0
        x = jax.nn.relu(x)
        x = nn.Dense(features=self.out)(x)  # Dense_1
        return x


class NNXMLP(nnx.Module):
    def __init__(self, in_features, hidden, out, rngs):
        self.fc1 = nnx.Linear(in_features, hidden, rngs=rngs)  # receives Dense_0
        self.fc2 = nnx.Linear(hidden, out, rngs=rngs)  # receives Dense_1

    def __call__(self, x):
        x = self.fc1(x)
        x = jax.nn.relu(x)  # must mirror the Linen activation exactly
        x = self.fc2(x)
        return x
```
Same math, same dims, different frame.
The port
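```python
# Translation table: Dense_0 -> fc1, Dense_1 -> fc2.
p = linen_params["params"]  # {"Dense_0": ..., "Dense_1": ...}
nnx_mlp.fc1.kernel.value = p["Dense_0"]["kernel"]  # (in_features, hidden)
nnx_mlp.fc1.bias.value = p["Dense_0"]["bias"]  # (hidden,)
nnx_mlp.fc2.kernel.value = p["Dense_1"]["kernel"]  # (hidden, out)
nnx_mlp.fc2.bias.value = p["Dense_1"]["bias"]  # (out,)
```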
Now nnx_mlp(x) returns the same output as linen_mlp.apply(linen_params, x).
Why same seed gives different params (briefly)
You might be tempted to assume nnx.Rngs(seed) and
jax.random.PRNGKey(seed) produce identical kernels for matching
layers. They don't. Linen derives each submodule's key from the root key
and the submodule's path (the Dense_0, Dense_1 names that @nn.compact
assigns), while nnx draws keys from its Rngs stream in the order the
layers are constructed. The kernels are statistically similar but
numerically distinct.
The whole point of porting is to side-step that mismatch: don’t try to reproduce Linen’s init in nnx. Just copy the arrays.
A scalable pattern
For larger models — Transformer, ResNet, etc. — write a small porting function:
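```python
from functools import reduce

def port(linen_params, nnx_model, mapping):
    for linen_path, nnx_path in mapping.items():
        arr = reduce(lambda d, k: d[k], linen_path, linen_params)  # tuple of dict keys
        var = reduce(getattr, nnx_path, nnx_model)  # tuple of attribute names
        var.value = arr  # overwrite the nnx Param in place
```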
The mapping is the only project-specific piece; the rest is mechanical.
Common pitfalls
- Wrong Dense_<i> index. Indices count from 0 and increment in declaration order inside @nn.compact. If you reorder declarations, the indices shift.
- Skipped activation on one side but not the other. Every activation between Dense calls in the Linen @nn.compact body (here, the ReLU) must be mirrored exactly in the nnx forward pass.
- Mismatched dims at port time. Both nnx.Linear(in_f, out_f, ...) arguments must match the shape of the corresponding Linen params. If the Linen Dense_0 kernel is (in_features, hidden), your nnx fc1 must be built as nnx.Linear(in_features, hidden, ...) so its kernel is also (in_features, hidden).
- Dropping bias. If one side disables bias (use_bias=False) but the other doesn't, either params["..."]["bias"] doesn't exist on the Linen side or the assignment errors on the nnx side.
Problem
Define LinenMLP(hidden, out) and NNXMLP(in_features, hidden, out, rngs) exactly as shown above. Then write port_mlp(seed, x, hidden, out):
- Build and init linen_mlp = LinenMLP(...). Compute linen_out.
- Build nnx_mlp = NNXMLP(...) with nnx.Rngs(int(seed) + 1) (its init values will be overwritten).
- Port the four arrays:
  - nnx_mlp.fc1.kernel.value = linen_params["params"]["Dense_0"]["kernel"]
  - nnx_mlp.fc1.bias.value = linen_params["params"]["Dense_0"]["bias"]
  - nnx_mlp.fc2.kernel.value = linen_params["params"]["Dense_1"]["kernel"]
  - nnx_mlp.fc2.bias.value = linen_params["params"]["Dense_1"]["bias"]
- Compute nnx_out. Return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]). The two sums must be equal.

Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden, out: ints (passed as floats).

Output: length-2 array [s, s] (sums equal).