
NNX Port: Linen MLP → NNX MLP

Why this matters

The single-Dense port (previous problem) is the building block. Real models have many layers, and you have to keep track of which Linen-side weight maps to which nnx-side attribute.

Linen and nnx use different naming schemes:

  • Linen auto-names submodules in @nn.compact as Dense_0, Dense_1, … in declaration order. Their params live under linen_params["params"]["Dense_0"], ["Dense_1"], …
  • nnx uses the attribute names you assigned: self.fc1, self.fc2, …

Porting a multi-layer model comes down to a translation table from Dense_<i> to fc<j>. Build it once, follow it carefully, and you're done.

The two architectures

import jax
import flax.linen as nn
from flax import nnx

class LinenMLP(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden)(x)   # Dense_0
        x = jax.nn.relu(x)
        x = nn.Dense(features=self.out)(x)      # Dense_1
        return x

class NNXMLP(nnx.Module):
    def __init__(self, in_features, hidden, out, rngs):
        self.fc1 = nnx.Linear(in_features, hidden, rngs=rngs)
        self.fc2 = nnx.Linear(hidden, out, rngs=rngs)
    def __call__(self, x):
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)
        return x

Same math, same dims, different frame.

The port

p = linen_params["params"]            # {"Dense_0": ..., "Dense_1": ...}

nnx_mlp.fc1.kernel.value = p["Dense_0"]["kernel"]
nnx_mlp.fc1.bias.value   = p["Dense_0"]["bias"]
nnx_mlp.fc2.kernel.value = p["Dense_1"]["kernel"]
nnx_mlp.fc2.bias.value   = p["Dense_1"]["bias"]

Now nnx_mlp(x) == linen_mlp.apply(linen_params, x).

Why same seed gives different params (briefly)

You might be tempted to assume nnx.Rngs(seed) and jax.random.PRNGKey(seed) produce identical kernels for matching layers. They don't. Linen derives each parameter's init key from the module's scope path (including auto-generated names like Dense_0) and a per-scope counter, while nnx.Rngs hands out keys from a stream in the order layers are constructed. Same seed, different derived keys: the kernels come from the same initializer distribution but are numerically distinct.

The whole point of porting is to sidestep that mismatch: don't try to reproduce Linen's init in nnx. Just copy the arrays.

A scalable pattern

For larger models — Transformer, ResNet, etc. — write a small porting function:

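import functools

def port(linen_params, nnx_model, mapping):
    for linen_path, nnx_path in mapping.items():
        arr = functools.reduce(lambda d, k: d[k], linen_path, linen_params)  # walk a tuple of dict keys
        *parents, leaf = nnx_path                                            # tuple of attribute names
        owner = functools.reduce(getattr, parents, nnx_model)                # e.g. nnx_model.fc1
        getattr(owner, leaf).value = arr                                     # e.g. fc1.kernel.value = arr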

Mapping is the only project-specific piece. The rest is mechanical.

Common pitfalls

  • Wrong Dense_<i> index. Counts from 0, increments by declaration order in @nn.compact. If you reorder declarations, indices shift.
  • Activation on one side but not the other. Only the weights are ported, so every jax.nn.relu between Dense calls in the Linen @nn.compact body must appear at the same place in the nnx __call__.
  • Mismatched dims at port time. The in_features and out_features passed to nnx.Linear must match the shape of the corresponding Linen params. If the Linen Dense_0 kernel is (in_features, hidden), your nnx fc1 kernel must also be (in_features, hidden).
  • Dropping bias. If only one side sets use_bias=False, the copy fails: either params["..."]["bias"] is missing on the Linen side (a KeyError), or the nnx layer has no bias Param to assign into.

Problem

Define LinenMLP(hidden, out) and NNXMLP(in_features, hidden, out, rngs) exactly as shown above. Then write port_mlp(seed, x, hidden, out):

  1. Build & init linen_mlp = LinenMLP(...). Compute linen_out.
  2. Build nnx_mlp = NNXMLP(...) with nnx.Rngs(int(seed) + 1) (init values will be overwritten).
  3. Port the four arrays:
    • nnx_mlp.fc1.kernel.value = linen_params["params"]["Dense_0"]["kernel"].
    • nnx_mlp.fc1.bias.value = linen_params["params"]["Dense_0"]["bias"].
    • nnx_mlp.fc2.kernel.value = linen_params["params"]["Dense_1"]["kernel"].
    • nnx_mlp.fc2.bias.value = linen_params["params"]["Dense_1"]["bias"].
  4. Compute nnx_out. Return jnp.array([float(nnx_out.sum()), float(linen_out.sum())]). The two sums must be equal.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hidden, out: ints (passed as floats).

Output: length-2 array [s, s] (sums equal).
